Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
d58d1762
Commit
d58d1762
authored
Mar 09, 2022
by
Lawrence McAfee
Browse files
passing args, timers from train_step to optimizer methods.
parent
45b364b1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
113 deletions
+12
-113
megatron/optimizer/distrib_optimizer.py
megatron/optimizer/distrib_optimizer.py
+3
-103
megatron/optimizer/optimizer.py
megatron/optimizer/optimizer.py
+5
-5
megatron/training.py
megatron/training.py
+4
-5
No files found.
megatron/optimizer/distrib_optimizer.py
View file @
d58d1762
...
@@ -402,128 +402,28 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
...
@@ -402,128 +402,28 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self
.
allreduce_embedding_grads
()
self
.
allreduce_embedding_grads
()
timers
(
'backward-embedding-all-reduce'
).
stop
()
timers
(
'backward-embedding-all-reduce'
).
stop
()
# # All-reduce word_embeddings' grad across first and last stages to ensure
# Reduce-scatter all grads.
# # that word_embeddings parameters stay in sync.
# # This should only run for models that support pipelined model parallelism
# # (BERT and GPT-2).
# timers('backward-embedding-all-reduce').start()
# if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
# mpu.get_pipeline_model_parallel_world_size() > 1:
# if mpu.is_pipeline_first_stage(ignore_virtual=True):
# unwrapped_model = model[0]
# elif mpu.is_pipeline_last_stage(ignore_virtual=True):
# unwrapped_model = model[-1]
# else: # We do not support the interleaved schedule for T5 yet.
# unwrapped_model = model[0]
# unwrapped_model = unwrap_model(
# unwrapped_model, (torchDDP, LocalDDP, Float16Module))
# if unwrapped_model.share_word_embeddings:
# word_embeddings_weight = unwrapped_model.word_embeddings_weight()
# if args.DDP_impl == 'local':
# grad = word_embeddings_weight.main_grad
# else:
# raise Exception("only 'main_grad' supported for distrib-opt.")
# grad = word_embeddings_weight.grad
# torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# # All-reduce position_embeddings grad across first (encoder) and split (decoder)
# # stages to ensure that position embeddings parameters stay in sync.
# # This should only run for T5 models with pipeline parallelism
# if mpu.is_rank_in_position_embedding_group() and \
# mpu.get_pipeline_model_parallel_world_size() > 1 and \
# args.pipeline_model_parallel_split_rank is not None:
# # >>>
# raise Exception("[fix] ready for t5 sync?")
# # <<<
# unwrapped_model = model[0]
# unwrapped_model = unwrap_model(
# unwrapped_model, (torchDDP, LocalDDP, Float16Module))
# assert args.DDP_impl == 'local', \
# 'T5 model is only supported with local DDP mode'
# # >>>
# grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
# torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
# # +++
# # grad_shard = optimizer.get_grad_shard(
# # unwrapped_model.language_model.embedding.position_embeddings.weight)
# # torch.distributed.all_reduce(grad_shard,
# # group=mpu.get_position_embedding_group())
# # <<<
# timers('backward-embedding-all-reduce').stop()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# timers('backward-params-reduce-scatter').start()
timers
(
'backward-params-all-reduce'
).
start
()
timers
(
'backward-params-all-reduce'
).
start
()
data_parallel_rank
=
mpu
.
get_data_parallel_rank
()
data_parallel_rank
=
mpu
.
get_data_parallel_rank
()
data_parallel_world_size
=
mpu
.
get_data_parallel_world_size
()
data_parallel_world_size
=
mpu
.
get_data_parallel_world_size
()
data_parallel_group
=
mpu
.
get_data_parallel_group
()
data_parallel_group
=
mpu
.
get_data_parallel_group
()
gbuf_view_items
=
self
.
get_model_grad_buffer_dp_views
()
gbuf_view_items
=
self
.
get_model_grad_buffer_dp_views
()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
# pax(0, {"gbufs": [
# g.data
# for m in self.models
# for g in m._grad_buffers.values()
# ]})
# >>>
# buffer_.data /= mpu.get_data_parallel_world_size()
# torch.distributed.all_reduce(
# buffer_.data, group=mpu.get_data_parallel_group())
# <<<
# >>>
# self.debug_main_param(0, "before reduce scatter")
# self.debug_main_grad(0, "before reduce scatter")
# <<<
for
model_index
,
dtype
,
gbuf_views
in
gbuf_view_items
:
for
model_index
,
dtype
,
gbuf_views
in
gbuf_view_items
:
# coalesced /= mpu.get_data_parallel_world_size()
gbuf
=
self
.
models
[
model_index
].
_grad_buffers
[
dtype
].
data
gbuf
=
self
.
models
[
model_index
].
_grad_buffers
[
dtype
].
data
# >>>
# ~~ distributed.py ~~
# gbuf /= data_parallel_world_size
# torch.distributed.all_reduce(gbuf, group=data_parallel_group)
# pax(0, {
# "gbuf" : tp(gbuf),
# })
# <<<
# torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
# gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
gbuf
/=
data_parallel_world_size
gbuf
/=
data_parallel_world_size
# if 1:
# try:
# pax(0, {"gbuf_views": gbuf_views})
torch
.
distributed
.
reduce_scatter
(
torch
.
distributed
.
reduce_scatter
(
gbuf_views
[
data_parallel_rank
],
gbuf_views
[
data_parallel_rank
],
gbuf_views
,
gbuf_views
,
group
=
data_parallel_group
,
group
=
data_parallel_group
,
)
)
# except:
# pax(0, {
# "data_parallel_rank" : data_parallel_rank,
# "gbuf_views" : gbuf_views,
# })
# else:
# torch.distributed.all_reduce(
# gbuf,
# group = data_parallel_group,
# )
# timers('backward-params-reduce-scatter').stop()
timers
(
'backward-params-all-reduce'
).
stop
()
timers
(
'backward-params-all-reduce'
).
stop
()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
def
gather_params
(
self
,
ITERATION
):
def
gather_params
(
self
,
args
,
timers
,
ITERATION
):
# >>>
# >>>
timers
=
get_timers
()
#
timers = get_timers()
# <<<
# <<<
timers
(
'backward-params-all-gather'
).
start
()
timers
(
'backward-params-all-gather'
).
start
()
...
...
megatron/optimizer/optimizer.py
View file @
d58d1762
...
@@ -177,10 +177,10 @@ class MegatronOptimizer(ABC):
...
@@ -177,10 +177,10 @@ class MegatronOptimizer(ABC):
@
abstractmethod
@
abstractmethod
def
step
(
self
):
def
step
(
self
,
args
,
timers
):
pass
pass
def
gather_params
(
self
,
ITERATION
):
def
gather_params
(
self
,
args
,
timers
,
ITERATION
):
pass
pass
def
allreduce_word_embedding_grads
(
self
):
def
allreduce_word_embedding_grads
(
self
):
...
@@ -400,9 +400,9 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
...
@@ -400,9 +400,9 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
step
(
self
,
ITERATION
):
def
step
(
self
,
args
,
timers
,
ITERATION
):
timers
=
get_timers
()
#
timers = get_timers()
# >>>
# >>>
# self.debug_model(ITERATION, "before copy grad.", 0)
# self.debug_model(ITERATION, "before copy grad.", 0)
...
@@ -720,7 +720,7 @@ class FP32Optimizer(MegatronOptimizer):
...
@@ -720,7 +720,7 @@ class FP32Optimizer(MegatronOptimizer):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
step
(
self
):
def
step
(
self
,
args
,
timers
):
"""Clip gradients (if needed) and step the base optimizer.
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
Always return successful since there is no overflow."""
...
...
megatron/training.py
View file @
d58d1762
...
@@ -437,18 +437,17 @@ def train_step(forward_step_func, data_iterator,
...
@@ -437,18 +437,17 @@ def train_step(forward_step_func, data_iterator,
# >>>
# >>>
# Reduce gradients.
# Reduce gradients.
optimizer
.
reduce_grads
(
args
,
timers
)
# model)
optimizer
.
reduce_grads
(
args
,
timers
)
# <<<
# <<<
# Update parameters.
# Update parameters.
timers
(
'optimizer'
).
start
()
timers
(
'optimizer'
).
start
()
update_successful
,
grad_norm
,
num_zeros_in_grad
=
optimizer
.
step
(
ITERATION
)
update_successful
,
grad_norm
,
num_zeros_in_grad
=
optimizer
.
step
(
args
,
timers
,
ITERATION
)
timers
(
'optimizer'
).
stop
()
timers
(
'optimizer'
).
stop
()
# >>>
# >>>
# Gather params gradients. (with distributed optimizer option, optimizer
# Gather params.
# now responsible for gathering updated params)
optimizer
.
gather_params
(
args
,
timers
,
ITERATION
)
optimizer
.
gather_params
(
ITERATION
)
# <<<
# <<<
# >>>
# >>>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment