OpenDAS / Megatron-LM · Commits

Commit 98feae4e (parent de6640be), authored Jul 09, 2020 by mohammad:

    added allgather and allreduce

1 changed file: pretrain_ict.py (+93, -12)
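The commit adds two torch.autograd.Function wrappers over the data-parallel group, AllgatherFromDataParallelRegion and AllReduceFromDataParallelRegion, together with two small helpers they share, and then reworks the logit merge in forward_step so that the pre-existing hand-rolled all_reduce path, an allgather path, and an allreduce path can be selected via an IMPLEMENTATION flag. The first hunk adds the helpers: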
```diff
@@ -57,6 +57,73 @@ def model_provider():
     return general_model_provider(False, False)
 
 
+def get_group_world_size_rank():
+    group = mpu.get_data_parallel_group()
+    rank = torch.distributed.get_rank(group=group)
+    world_size = torch.distributed.get_world_size(group=group)
+    return group, rank, world_size
+
+
+def get_rank_chunk_along_first_dim(tensor):
+    group, rank, world_size = get_group_world_size_rank()
+    # The first dimension must shard evenly across the data-parallel ranks.
+    assert tensor.shape[0] % world_size == 0
+    dim_size = tensor.shape[0] // world_size
+    output_list = torch.split(tensor, dim_size, dim=0)
+    output = output_list[rank].contiguous()
+    return output
```
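A quick way to see what get_rank_chunk_along_first_dim does is to run its slicing logic locally: for a tensor whose first dimension is world_size × chunk, each rank keeps only its own chunk. A minimal single-process sketch (the fixed rank and world_size below are stand-ins for the values mpu would return, not part of the commit):

```python
import torch

# Stand-ins for what get_group_world_size_rank() would return in a
# real data-parallel run (hypothetical values, for illustration only).
world_size, rank = 4, 2

tensor = torch.arange(8).reshape(8, 1)    # first dim divisible by world_size
assert tensor.shape[0] % world_size == 0
dim_size = tensor.shape[0] // world_size  # 2 rows per rank
chunks = torch.split(tensor, dim_size, dim=0)
print(chunks[rank].squeeze(1))            # tensor([4, 5]) -- rank 2's rows
```

The hunk continues with the first of the two autograd wrappers: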
```diff
+class AllgatherFromDataParallelRegion(torch.autograd.Function):
+    """All-gather 2D activations across the data-parallel group;
+    backward returns this rank's chunk of the gradient."""
+
+    @staticmethod
+    def forward(ctx, input_):
+        assert input_.dim() == 2
+        group, rank, world_size = get_group_world_size_rank()
+        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+        tensor_list[rank] = input_
+        torch.distributed.all_gather(tensor_list, input_, group=group)
+        output = torch.cat(tensor_list, dim=0).contiguous()
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return get_rank_chunk_along_first_dim(grad_output)
```
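The backward here is the standard adjoint of a gather: if the forward concatenates per-rank inputs along dim 0, then the gradient of each rank's input is exactly its own slice of grad_output. A single-process sketch that simulates the gather with torch.cat so ordinary autograd can confirm this (nothing below is from the commit; no process group is needed):

```python
import torch

world_size = 3
# Simulate each rank's 2x2 input (hypothetical data).
inputs = [torch.randn(2, 2, requires_grad=True) for _ in range(world_size)]

# Forward of the gather: concatenate along the first dimension.
output = torch.cat(inputs, dim=0)   # shape (6, 2)
loss = output.sum()
loss.backward()

# Each input's gradient is exactly its chunk of d(loss)/d(output) -- all
# ones here -- which is what get_rank_chunk_along_first_dim(grad_output)
# returns in the wrapper's backward.
print(inputs[1].grad)               # tensor of ones, shape (2, 2)
```

The all-reduce variant follows the same pattern, building the gathered tensor out of zero padding instead of an explicit all_gather: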
```diff
+class AllReduceFromDataParallelRegion(torch.autograd.Function):
+    """Build the gathered tensor by zero-padding the local chunk and
+    summing with all_reduce; backward returns this rank's chunk."""
+
+    @staticmethod
+    def forward(ctx, input_):
+        assert input_.dim() == 2
+        group, rank, world_size = get_group_world_size_rank()
+        # Every slot except this rank's stays zero, so the all_reduce
+        # sum reconstructs the full concatenation.
+        tensor_list = [torch.zeros_like(input_) for _ in range(world_size)]
+        tensor_list[rank] = input_
+        output = torch.cat(tensor_list, dim=0).contiguous()
+        torch.distributed.all_reduce(output, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return get_rank_chunk_along_first_dim(grad_output)
+
+
 def get_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_pad_mask',
```
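Zero-padding each rank's contribution and summing with all_reduce reconstructs the same concatenation an all_gather would produce, since every row is nonzero on exactly one rank. A local simulation of the ranks with plain tensors (not from the commit; no process group needed):

```python
import torch

world_size = 3
inputs = [torch.full((2, 2), float(r)) for r in range(world_size)]  # fake per-rank logits

# What each rank contributes before the all_reduce: its own rows, zeros elsewhere.
padded = []
for rank in range(world_size):
    tensor_list = [torch.zeros_like(inputs[rank]) for _ in range(world_size)]
    tensor_list[rank] = inputs[rank]
    padded.append(torch.cat(tensor_list, dim=0))

# all_reduce is a sum across ranks; summing the padded tensors reproduces
# the plain concatenation an all_gather would build.
reduced = torch.stack(padded).sum(dim=0)
assert torch.equal(reduced, torch.cat(inputs, dim=0))
```

The second hunk threads all three merge strategies through forward_step behind the IMPLEMENTATION flag: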
```diff
@@ -95,21 +162,35 @@ def forward_step(data_iterator, model):
 
     # Forward model.
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
 
-    data_parallel_size = dist.get_world_size() / args.model_parallel_size
-    batch_size = query_logits.shape[0]
-    global_batch_size = int(batch_size * data_parallel_size)
-    all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
-    all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0)
-    all_block_logits = all_query_logits.clone()
+    IMPLEMENTATION = 'original'
+    if IMPLEMENTATION == 'original':
+        data_parallel_size = dist.get_world_size() / args.model_parallel_size
+        batch_size = query_logits.shape[0]
+        global_batch_size = int(batch_size * data_parallel_size)
+        all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
+        all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0)
+        all_block_logits = all_query_logits.clone()
 
-        # record this process's data
-        all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
-        all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
+        # record this process's data
+        all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
+        all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
 
-        # merge data from all processes
-        dist.all_reduce(all_query_logits)
-        dist.all_reduce(all_block_logits)
+        # merge data from all processes
+        dist.all_reduce(all_query_logits)
+        dist.all_reduce(all_block_logits)
+    elif IMPLEMENTATION == 'allgather':
+        all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+        all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
+    elif IMPLEMENTATION == 'allreduce':
+        all_query_logits = AllReduceFromDataParallelRegion.apply(query_logits)
+        all_block_logits = AllReduceFromDataParallelRegion.apply(block_logits)
+    else:
+        raise Exception('should not be here.')
 
     # scores are inner products between query and block embeddings
     retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
```
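All three paths produce all_query_logits and all_block_logits of shape (global_batch_size, hidden), and retrieval_scores becomes the global_batch_size × global_batch_size matrix of query-block inner products. One practical difference, presumably the motivation for the new wrappers: dist.all_reduce in the 'original' path is not autograd-aware, so only each rank's own slice assignment participates in the graph, whereas the Function classes give the merge an explicit backward (the rank's chunk of grad_output). A small sketch of the shape arithmetic, with made-up sizes:

```python
import torch

data_parallel_size, batch_size, hidden = 4, 8, 128  # made-up sizes
global_batch_size = batch_size * data_parallel_size

all_query_logits = torch.randn(global_batch_size, hidden)
all_block_logits = torch.randn(global_batch_size, hidden)

# Inner products between every query and every block embedding;
# row i, column j scores query i against block j.
retrieval_scores = all_query_logits.float().matmul(
    torch.transpose(all_block_logits, 0, 1).float())
print(retrieval_scores.shape)  # torch.Size([32, 32])
```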