OpenDAS / Megatron-LM · Commits

Commit e5bda3c9, authored Mar 02, 2022 by Lawrence McAfee

    working w/ mixed parallelism [ dp, tp, pp ].

Parent: 2f99b4f4
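
The [ dp, tp, pp ] in the commit message are the data-, tensor-, and pipeline-model-parallel degrees. A minimal sketch of how the three degrees fit together; none of this code is from the commit and all names are hypothetical:

# Minimal sketch (not code from this commit): how the dp/tp/pp degrees in the
# commit message relate. All names here are illustrative, not Megatron-LM API.

def infer_data_parallel_size(world_size, tensor_parallel_size, pipeline_parallel_size):
    """Ranks form a (dp x pp x tp) grid, so the product of the three
    degrees must equal the total number of ranks."""
    model_parallel_size = tensor_parallel_size * pipeline_parallel_size
    if world_size % model_parallel_size != 0:
        raise ValueError("world_size %d is not divisible by tp*pp = %d"
                         % (world_size, model_parallel_size))
    return world_size // model_parallel_size

# Example: 16 ranks with tp=2 and pp=4 leave dp=2 data-parallel replicas.
assert infer_data_parallel_size(16, 2, 4) == 2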

Changes: 1 changed file with 10 additions and 13 deletions

    megatron/optimizer/optimizer.py    +10  -13

megatron/optimizer/optimizer.py @ e5bda3c9
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -598,12 +598,6 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         timers = get_timers()
         # <<<
 
-        # >>>
-        # pax(0, {
-        #     "grads" : [ p.main_grad for m in model for p in m.parameters() ],
-        # })
-        # <<<
-
         # All-reduce if needed.
         if args.DDP_impl == 'local':
             timers('backward-params-all-reduce').start()
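
The lines removed in this hunk were pax() debug probes around the existing data-parallel reduction; when args.DDP_impl == 'local', each model chunk's allreduce_gradients() averages gradients across data-parallel replicas. A minimal sketch of that reduction, assuming parameters that carry an fp32 main_grad buffer and a torch.distributed data-parallel group (illustrative, not LocalDDP's actual code, which typically reduces a contiguous grad buffer instead):

# Minimal sketch (an assumption, not LocalDDP's implementation): averaging each
# parameter's fp32 .main_grad buffer across a data-parallel process group.
import torch.distributed as dist

def allreduce_main_grads(params, data_parallel_group):
    world_size = dist.get_world_size(group=data_parallel_group)
    for p in params:
        main_grad = getattr(p, 'main_grad', None)
        if main_grad is None:
            continue
        dist.all_reduce(main_grad, group=data_parallel_group)
        main_grad.div_(world_size)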
@@ -611,12 +605,6 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
                 model_module.allreduce_gradients()
             timers('backward-params-all-reduce').stop()
 
-        # >>>
-        # pax(0, {
-        #     "grads" : [ p.main_grad for m in model for p in m.parameters() ],
-        # })
-        # <<<
-
         # All-reduce word_embeddings' grad across first and last stages to ensure
         # that word_embeddings parameters stay in sync.
         # This should only run for models that support pipelined model parallelism
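
The context comment above describes the reduction that keeps the tied word embedding identical on the first and last pipeline stages. A minimal sketch of that step, assuming a process group containing just those stages (illustrative names, not the optimizer's exact code):

# Minimal sketch (an assumption): summing the tied word-embedding gradient
# across the embedding group, i.e. the first and last pipeline stages, so the
# two copies of the weight stay in sync after the optimizer step.
import torch.distributed as dist

def allreduce_word_embedding_grad(word_embeddings_weight, embedding_group):
    grad = getattr(word_embeddings_weight, 'main_grad', None)
    if grad is None:
        grad = word_embeddings_weight.grad
    dist.all_reduce(grad, group=embedding_group)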
@@ -1246,6 +1234,15 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     def reduce_grads(self, model):
 
+        # >>>
+        from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+        from megatron import get_args
+        from megatron import get_timers
+        from megatron.model import DistributedDataParallel as LocalDDP
+        from megatron.model import Float16Module
+        from megatron.utils import unwrap_model
+        args = get_args()
+        timers = get_timers()
+        # <<<
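
The block added to Float16DistributedOptimizer.reduce_grads pulls in get_args, get_timers, the two DDP wrappers, Float16Module, and unwrap_model, apparently mirroring the setup used by the non-distributed optimizer above. A minimal sketch of what an unwrap_model-style helper does with those wrapper classes (illustrative, not megatron.utils.unwrap_model itself):

# Minimal sketch: peel wrapper modules off each model chunk so the optimizer
# can reach the underlying language model. Wrappers are assumed to expose a
# .module attribute, as the DDP wrappers and Float16Module do.
def unwrap(model_chunks, wrapper_types):
    unwrapped = []
    for chunk in model_chunks:
        while isinstance(chunk, wrapper_types):
            chunk = chunk.module
        unwrapped.append(chunk)
    return unwrapped

# Usage with the names imported above: unwrap(model, (torchDDP, LocalDDP, Float16Module))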
@@ -1262,7 +1259,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
             # >>>
-            raise Exception("[fix] ready for weight sync?")
+            # raise Exception("[fix] ready for weight sync?")
             # <<<
             if mpu.is_pipeline_first_stage(ignore_virtual=True):
                 unwrapped_model = model[0]
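
With the placeholder exception commented out, the guarded branch below it now runs on ranks in the embedding group whenever pipeline parallelism is active. A minimal sketch of what that branch selects, based only on the visible context; the mpu helper names are taken as given:

# Minimal sketch (an assumption based on the surrounding context): choosing
# which pipeline stage's chunk owns the shared word embedding before the
# embedding-group reduction runs.
def select_embedding_chunk(model, mpu):
    if mpu.is_pipeline_first_stage(ignore_virtual=True):
        return model[0]       # first stage holds the input embedding
    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        return model[-1]      # last stage holds the tied output embedding
    return None               # other stages take no part in the sync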