OpenDAS / Megatron-LM · Commits

Commit 628bf0dd
Authored Jul 13, 2020 by Neel Kant

Use the new allgather implementation

Parent: 98feae4e
Showing 1 changed file with 10 additions and 63 deletions

pretrain_ict.py  (+10 / -63)
@@ -57,7 +57,6 @@ def model_provider():
     return general_model_provider(False, False)
 
 
 def get_group_world_size_rank():
     group = mpu.get_data_parallel_group()
@@ -67,23 +66,10 @@ def get_group_world_size_rank():
     return group, rank, world_size
 
 
-def get_rank_chunk_along_first_dim(tensor):
-    group, rank, world_size = get_group_world_size_rank()
-    assert tensor.shape[0] % world_size == 0
-    dim_size = tensor.shape[0] // world_size
-    output_list = torch.split(tensor, dim_size, dim=0)
-    output = output_list[rank].contiguous()
-    return output
-
-
 class AllgatherFromDataParallelRegion(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_):
         assert input_.dim() == 2
         group, rank, world_size = get_group_world_size_rank()
@@ -98,32 +84,17 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        return get_rank_chunk_along_first_dim(grad_output)
-
-
-class AllReduceFromDataParallelRegion(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input_):
-        assert input_.dim() == 2
         group, rank, world_size = get_group_world_size_rank()
-        tensor_list = [torch.zeros_like(input_) for _ in range(world_size)]
-        tensor_list[rank] = input_
-        output = torch.cat(tensor_list, dim=0).contiguous()
-        torch.distributed.all_reduce(output, group=group)
+        assert grad_output.shape[0] % world_size == 0
+        dim_size = grad_output.shape[0] // world_size
+        output_list = torch.split(grad_output, dim_size, dim=0)
+
+        # get chunk from this rank
+        output = output_list[rank].contiguous()
         return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return get_rank_chunk_along_first_dim(grad_output)
 
 
 def get_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_pad_mask',
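For reference, here is a minimal, self-contained sketch (not part of this commit) of the pattern AllgatherFromDataParallelRegion implements: the forward pass all-gathers a 2-D tensor along its first dimension across a process group, and the backward pass hands each rank back only its own chunk of the incoming gradient. The class name AllgatherWithGrad and the single-process gloo setup are illustrative assumptions, not code from the repository.

import os
import torch
import torch.distributed as dist


class AllgatherWithGrad(torch.autograd.Function):
    # Illustrative stand-in for the all-gather-with-gradient pattern above.

    @staticmethod
    def forward(ctx, input_):
        ctx.rank = dist.get_rank()
        ctx.world_size = dist.get_world_size()
        # gather every rank's tensor and stack them along the first dimension
        tensor_list = [torch.empty_like(input_) for _ in range(ctx.world_size)]
        dist.all_gather(tensor_list, input_.contiguous())
        return torch.cat(tensor_list, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # each rank keeps only the gradient slice that corresponds to its own input
        chunk = grad_output.shape[0] // ctx.world_size
        return grad_output[ctx.rank * chunk:(ctx.rank + 1) * chunk].contiguous()


if __name__ == '__main__':
    # single-process group so the sketch runs end to end without multiple GPUs
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    x = torch.randn(4, 8, requires_grad=True)
    y = AllgatherWithGrad.apply(x)
    y.sum().backward()
    print(y.shape, x.grad.shape)  # torch.Size([4, 8]) torch.Size([4, 8])

Like the class in the diff, this backward simply takes the local slice of the incoming gradient rather than reduce-scattering contributions from other ranks.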
@@ -159,38 +130,14 @@ def forward_step(data_iterator, model):
     block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
     timers('batch generator').stop()
 
     # Forward model.
     query_logits, block_logits = model(query_tokens, query_pad_mask,
                                        block_tokens, block_pad_mask)
 
+    local_batch_size = query_logits.shape[0]
+    global_batch_size = dist.get_world_size() * local_batch_size  # recall we assert that model_parallel_size == 1
+
-    IMPLEMENTATION = 'original'
-    if IMPLEMENTATION == 'original':
-        data_parallel_size = dist.get_world_size() / args.model_parallel_size
-        batch_size = query_logits.shape[0]
-        global_batch_size = int(batch_size * data_parallel_size)
-        all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
-        all_query_logits = torch.cuda.FloatTensor(*all_logits_shape).type(query_logits.dtype).fill_(0.0)
-        all_block_logits = all_query_logits.clone()
-
-        # record this processes' data
-        all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
-        all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
-
-        # merge data from all processes
-        dist.all_reduce(all_query_logits)
-        dist.all_reduce(all_block_logits)
-    elif IMPLEMENTATION == 'allgather':
-        all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-        all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
-    elif IMPLEMENTATION == 'allreduce':
-        all_query_logits = AllReduceFromDataParallelRegion.apply(query_logits)
-        all_block_logits = AllReduceFromDataParallelRegion.apply(block_logits)
-    else:
-        raise Exception('should not be here.')
+    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
 
     # scores are inner products between query and block embeddings
     retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
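For orientation, a small self-contained sketch of what the gathered logits feed into: retrieval_scores is a global_batch_size x global_batch_size matrix of query-block inner products, and ICT-style training commonly treats block i as the positive for query i within the global batch. The toy sizes and the diagonal-target cross-entropy below are assumptions about that setup, not code taken from this file.

import torch
import torch.nn.functional as F

global_batch_size, hidden = 8, 128           # toy sizes, assumptions
all_query_logits = torch.randn(global_batch_size, hidden)
all_block_logits = torch.randn(global_batch_size, hidden)

# same expression as in forward_step above: inner products between every query and every block
retrieval_scores = all_query_logits.float().matmul(
    torch.transpose(all_block_logits, 0, 1).float())
print(retrieval_scores.shape)                # torch.Size([8, 8])

# assumed in-batch-negatives objective: query i should score highest against block i
targets = torch.arange(global_batch_size)
loss = F.cross_entropy(retrieval_scores, targets)
print(loss.item())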