OpenDAS / Megatron-LM

Commit 08ee8ea2, authored Mar 10, 2022 by Lawrence McAfee
Parent: c88bc979

updated args for allreduce_embeddings
Showing 2 changed files with 16 additions and 14 deletions (+16, -14):

    megatron/optimizer/distrib_optimizer.py    +1   -1
    megatron/optimizer/optimizer.py            +15  -13
megatron/optimizer/distrib_optimizer.py

@@ -334,7 +334,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
-        self.allreduce_embedding_grads()
+        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

        # Reduce-scatter all grads.
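The only change in this file is the extra `args` argument. For reference, a minimal self-contained sketch of the updated call chain (the method names mirror the diff below; the bodies are placeholders, not Megatron's real implementation): `allreduce_embedding_grads(args)` now forwards the parsed arguments to both embedding all-reduce helpers, and it is the position-embedding path that actually consumes them.

import types

# Stand-in sketch of the call chain touched by this commit. Method names mirror
# the diff; bodies are placeholders, not Megatron's implementation.
class OptimizerSketch:
    def allreduce_word_embedding_grads(self, args):
        print("word-embedding all-reduce, DDP_impl =", args.DDP_impl)

    def allreduce_position_embedding_grads(self, args):
        # In the real code this path reads args.pipeline_model_parallel_split_rank
        # and asserts args.DDP_impl == 'local' before all-reducing .main_grad.
        print("position-embedding all-reduce, split rank =",
              args.pipeline_model_parallel_split_rank)

    def allreduce_embedding_grads(self, args):
        # After this commit both helpers receive the parsed megatron args.
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

# Tiny usage example with a mock args namespace.
args = types.SimpleNamespace(DDP_impl='local',
                             pipeline_model_parallel_split_rank=None)
OptimizerSketch().allreduce_embedding_grads(args)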
megatron/optimizer/optimizer.py

@@ -17,15 +17,17 @@
from abc import ABC
from abc import abstractmethod
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

@@ -190,7 +192,7 @@ class MegatronOptimizer(ABC):
        do here.'''
        pass

-    def allreduce_word_embedding_grads(self):
+    def allreduce_word_embedding_grads(self, args):
        '''
        All-reduce word embedding grads.

@@ -202,11 +204,11 @@ class MegatronOptimizer(ABC):
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
-                unwrapped_model = model[0]
+                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
-                unwrapped_model = model[-1]
+                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
-                unwrapped_model = model[0]
+                unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

@@ -218,7 +220,7 @@ class MegatronOptimizer(ABC):
            grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

-    def allreduce_position_embedding_grads(self):
+    def allreduce_position_embedding_grads(self, args):
        '''
        All-reduce position_embeddings grad across first (encoder) and
        split (decoder) stages to ensure that position embeddings parameters

@@ -228,7 +230,7 @@ class MegatronOptimizer(ABC):
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
-            unwrapped_model = model[0]
+            unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \

@@ -236,9 +238,9 @@ class MegatronOptimizer(ABC):
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

-    def allreduce_embedding_grads(self):
-        self.allreduce_word_embedding_grads()
-        self.allreduce_position_embedding_grads()
+    def allreduce_embedding_grads(self, args):
+        self.allreduce_word_embedding_grads(args)
+        self.allreduce_position_embedding_grads(args)

    def reduce_model_grads(self, args, timers):

@@ -251,7 +253,7 @@ class MegatronOptimizer(ABC):
        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
-        self.allreduce_embedding_grads()
+        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()
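A hypothetical call site (a sketch, not code from this commit), assuming an already-initialized Megatron training run where get_args() and get_timers() are available: the training loop hands the parsed arguments and timers to the optimizer, and reduce_model_grads() forwards `args` down to allreduce_embedding_grads(args) as shown in the hunks above.

# Sketch of a call site; names other than the optimizer methods shown in the
# diff (e.g. finalize_step_grads) are illustrative only.
from megatron import get_args, get_timers

def finalize_step_grads(optimizer):
    args = get_args()
    timers = get_timers()
    # Internally: allreduce_embedding_grads(args) ->
    #   allreduce_word_embedding_grads(args) + allreduce_position_embedding_grads(args)
    optimizer.reduce_model_grads(args, timers)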