OpenDAS / Megatron-LM

Commit fb3a1345, authored Mar 29, 2022 by Lawrence McAfee

partially commented distrib_optimizer.py.

Parent: fe2d623e
Showing 1 changed file with 73 additions and 8 deletions.

megatron/optimizer/distrib_optimizer.py (+73, -8)
@@ -296,14 +296,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                  fp16, bf16, grad_scaler, models):
+        """
+        See top of class definition for argument descriptions.
+        """

         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp,
             fp16, bf16, grad_scaler, models)

-        # Verify that contiguous buffers are being used
-        # - Note: this should already be checked in arguments.py
+        # Verify that contiguous buffers are being used.
+        # - Note: this should already be checked in arguments.py.
         assert use_contiguous_buffers_in_local_ddp

         # Model grad buffer ranges.
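For context on the assertion above: the distributed optimizer relies on the local DDP's contiguous grad buffers, because each parameter's .main_grad is a view into one flat buffer per dtype, which can later be reduced with a single collective. A minimal sketch of that layout (hypothetical helper code, not Megatron's buffer implementation):

import torch

params = [torch.nn.Parameter(torch.randn(4)), torch.nn.Parameter(torch.randn(6))]

# One contiguous buffer; every parameter's main_grad is a view into it.
grad_buffer = torch.zeros(sum(p.numel() for p in params))
offset = 0
for p in params:
    p.main_grad = grad_buffer[offset:offset + p.numel()].view_as(p)
    offset += p.numel()

params[0].main_grad += 1.0   # accumulating into a view fills the flat buffer
print(grad_buffer[:4])       # the first parameter's grads, stored contiguously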
@@ -370,6 +373,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def load_state_dict(self, state_dict):
+        """
+        Load the state dict.
+        """

         # Optimizer.
         optimizer_key = 'optimizer'
@@ -400,11 +406,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def zero_grad(self, set_to_none=True):
-        """We only need to zero the model related parameters, i.e.,
-        float16_groups & fp32_groups. We additionally zero
-        fp32_from_float16_groups as a memory optimization to reduce
+        """
+        Zero grads.
+
+        We only need to zero the model related parameters, i.e.,
+        model_float16_groups & model_fp32_groups. We additionally zero
+        the remaining groups as a memory optimization to reduce
         fragmentation; in the case of set_to_none==True, the space
-        used by this field can be safely deallocated at this point."""
+        used by this field can be safely deallocated at this point.
+        """
         for groups in (
                 self.model_float16_groups,
                 self.model_fp32_groups,
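The set_to_none trade-off described in the new docstring can be shown on a single parameter; this is a generic PyTorch sketch, not Megatron code:

import torch

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 2.0)

# set_to_none=False: keep the grad tensor and overwrite it with zeros, so the
# storage stays allocated for the next accumulation.
p.grad.zero_()

# set_to_none=True: drop the reference entirely; the allocator can reuse the
# space until the next backward pass recreates the grad.
p.grad = None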
@@ -416,6 +426,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def get_model_grad_buffer_dp_views(self):
+        """
+        Get shard views of each of the DDP's grad buffers.
+
+        In this nested list, the top level is grouped by the virtual model
+        index and the grad buffer's data type. The sub-level is a list of
+        shards of that grad buffer, where each shard in the list represents
+        a contiguous view of the grad buffer, that is owned by a data-parallel
+        rank. The shard boundary does not respect parameter boundaries, and
+        so the elements of some parameters are split across data parallel
+        ranks.
+
+        Additionally, return references to the entire grad buffers, for use
+        in _reduce_scatter_base and _all_gather_base.
+        """

         data_parallel_world_size = mpu.get_data_parallel_world_size()
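The sharding described in this docstring amounts to cutting each flat buffer into equal, contiguous per-rank views. A minimal sketch (hypothetical names, not the Megatron implementation) that also illustrates why a single parameter's elements may straddle two data-parallel ranks:

import torch

def shard_buffer(buf, data_parallel_world_size):
    """Return one contiguous view of `buf` per data-parallel rank."""
    assert buf.numel() % data_parallel_world_size == 0
    shard_size = buf.numel() // data_parallel_world_size
    return [buf[r * shard_size:(r + 1) * shard_size]
            for r in range(data_parallel_world_size)]

grad_buffer = torch.zeros(12)          # flat buffer holding several params' grads
views = shard_buffer(grad_buffer, 4)   # 4 data-parallel ranks -> 4 views of 3 elems
views[2].fill_(1.0)                    # writing a view writes the shared buffer
print(grad_buffer)                     # elements 6..8 are now 1.0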
@@ -435,6 +459,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def reduce_model_grads(self, args, timers):
+        """
+        Reduce-scatter model grads.
+
+        The DDP's grad buffer is used for the reduce-scatter, and thus no
+        tensors are dynamically allocated.
+
+        Note: this is a different order of reduction, versus the non-
+        distributed optimizer, which reduces: 1) all grads, 2) embedding
+        grads.
+        """
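The reduce-scatter / shard-update / all-gather flow that this docstring (and the next two hunks) describe can be simulated on one process with plain tensors, no torch.distributed; the sizes and the SGD step below are illustrative only. The end state matches a conventional all-reduce followed by a full optimizer step:

import torch

world_size, shard_size = 4, 3

# One flat grad buffer per "rank", as produced by each rank's local backprop.
local_grads = [torch.randn(world_size * shard_size) for _ in range(world_size)]

# Reduce-scatter: after the collective, rank r holds the sum over ranks of shard r.
reduced = torch.stack(local_grads).sum(dim=0)
owned = [reduced[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]

# Each rank steps only the params it owns (plain SGD here for illustration).
params = torch.zeros(world_size * shard_size)
lr = 0.1
updated_shards = [params[r * shard_size:(r + 1) * shard_size] - lr * owned[r]
                  for r in range(world_size)]

# All-gather: concatenating the owned shards reproduces the full update that an
# all-reduce plus a replicated optimizer step would have produced.
gathered = torch.cat(updated_shards)
assert torch.allclose(gathered, params - lr * reduced)
print("sharded update matches the replicated update")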
@@ -458,7 +487,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

         # Reduce-scatter all grads.
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
+        for index, (model_index, dtype, gbuf, gbuf_views) \
+            in enumerate(gbuf_view_items):
             torch.distributed._reduce_scatter_base(gbuf_views[data_parallel_rank], gbuf,
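A minimal sketch of the call pattern above, assuming the script is launched with torchrun and an NCCL process group is already initialized (this is an illustration, not Megatron's code). The key point is that the output tensor is simply this rank's view into the same flat buffer used as input, so the collective allocates nothing extra:

import torch
import torch.distributed as dist

def reduce_scatter_grad_buffer(gbuf: torch.Tensor) -> torch.Tensor:
    """Reduce the flat grad buffer across ranks; return this rank's view."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert gbuf.numel() % world_size == 0
    shard_size = gbuf.numel() // world_size
    gbuf_views = [gbuf[r * shard_size:(r + 1) * shard_size]
                  for r in range(world_size)]
    # Output is this rank's view of the buffer; input is the whole buffer, so
    # the operation is effectively in place (same private API as the diff;
    # newer PyTorch exposes it as dist.reduce_scatter_tensor).
    dist._reduce_scatter_base(gbuf_views[rank], gbuf, op=dist.ReduceOp.SUM)
    return gbuf_views[rank]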
@@ -469,6 +500,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def gather_model_params(self, args, timers):
+        """
+        All-gather updated model params.
+
+        The DDP's grad buffer is used for the all-gather, and thus no
+        tensors are dynamically allocated. After the all-gather, the params
+        can be copied from param.main_grad to param.
+        """

         timers('backward-params-all-gather').start()
@@ -481,7 +519,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

         # in distributed.py. Thus, all sub-views will have consistent start/end
         # indexes across data parallel ranks.
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
+        for index, (model_index, dtype, gbuf, gbuf_views) \
+            in enumerate(gbuf_view_items):
             torch.distributed._all_gather_base(gbuf, gbuf_views[data_parallel_rank],
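A companion sketch for the all-gather path, under the same assumptions as the previous example (torchrun, initialized process group; hypothetical helper, not Megatron's code). After the optimizer step, each rank's updated param shard sits in its view of the flat buffer, and the all-gather fills in everyone else's shards in place:

import torch
import torch.distributed as dist

def all_gather_param_buffer(gbuf: torch.Tensor) -> None:
    """All-gather the per-rank shards of the flat buffer, in place."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    shard_size = gbuf.numel() // world_size
    my_view = gbuf[rank * shard_size:(rank + 1) * shard_size]
    # Inverse of the reduce-scatter call: input is this rank's view, output is
    # the whole buffer (newer PyTorch names this dist.all_gather_into_tensor).
    dist._all_gather_base(gbuf, my_view)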
@@ -499,6 +539,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def _collect_main_grad_data_for_unscaling(self):
+        """
+        Note: this should be equivalent to the float-16 optimizer's method,
+        but written differently, so the two should be combined.
+        """
         return [
             param.grad.data
             for group in self.optimizer.param_groups
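The flattening pattern used above can be reproduced with any torch optimizer; a generic sketch (not Megatron's optimizer state) of collecting every grad tensor into one flat list, e.g. for unscaling them in a single multi-tensor pass:

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(3, 4)).sum().backward()

# Walk every param group of the optimizer and collect each grad tensor.
main_grad_data = [
    param.grad.data
    for group in opt.param_groups
    for param in group['params']
    if param.grad is not None
]
print(len(main_grad_data))   # weight + bias -> 2 grad tensors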
@@ -507,6 +551,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def _get_model_and_main_params_data_float16(self):
+        """
+        Get aligned list of model and main params.
+        """
         model_data = []
         main_data = []
         for model_group, main_group in zip(self.shard_float16_groups,
@@ -518,7 +565,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def _copy_model_grads_to_main_grads(self):
+        """
+        Copy model grads to main grads.
+
+        Since this step follows a reduce-scatter through the DDP's grad
+        buffer, this method is responsible for copying the updated grads
+        from the grad buffer to the main shard's grad field.
+        """

+        # Utility method for copying group grads.
         def copy_group_grads(model_groups, shard_main_groups):
             for model_group, shard_main_group in zip(model_groups,
                                                      shard_main_groups):
@@ -534,6 +589,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

                         [param_range.start:param_range.end]
                     shard_main_param.grad = shard_model_grad.float()

+        # Copy model groups to shard groups.
         copy_group_grads(self.model_float16_groups,
                          self.shard_fp32_from_float16_groups)
         copy_group_grads(self.model_fp32_groups,
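The slice-and-cast inside copy_group_grads can be isolated in a few lines; the names and sizes below are hypothetical stand-ins for the grad-buffer view, the param_range, and the fp32 main shard:

import torch

model_main_grad = torch.randn(10, dtype=torch.bfloat16)          # view into the grad buffer
start, end = 4, 8                                                 # this rank's param_range
shard_main_param = torch.nn.Parameter(torch.zeros(end - start))   # fp32 main shard

# Only the [start:end] range owned by this rank is copied, cast up to fp32,
# into the main shard's .grad field for the optimizer step.
shard_model_grad = model_main_grad.view(-1)[start:end]
shard_main_param.grad = shard_model_grad.float()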
@@ -541,7 +597,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def _copy_main_params_to_model_params(self):
+        """
+        Copy main params to model params.
+
+        Since this step is followed by an all-gather through the DDP's grad
+        buffer, this method is responsible for copying the updated params
+        from the main shards into the correct position in the grad buffer.
+        """

+        # Utility method for copying group params.
         def copy_group_params(shard_main_groups, model_groups):
             for shard_main_group, model_group in zip(shard_main_groups,
                                                      model_groups):
@@ -558,6 +622,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

                     shard_model_grad.data.copy_(shard_main_param)

+        # Copy shard groups to model groups.
         copy_group_params(self.shard_fp32_from_float16_groups,
                           self.model_float16_groups)
         copy_group_params(self.shard_fp32_groups,
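A matching sketch for the copy-back in copy_group_params, with the same hypothetical names as the previous example: after the optimizer step, the fp32 main shard is written back into the same [start:end] slice of the buffer, which then carries updated params through the all-gather; copy_ performs the fp32 to fp16/bf16 downcast:

import torch

model_main_grad = torch.zeros(10, dtype=torch.bfloat16)   # grad buffer reused for params
start, end = 4, 8
shard_main_param = torch.randn(end - start)                # updated fp32 main shard

shard_model_grad = model_main_grad.view(-1)[start:end]
shard_model_grad.data.copy_(shard_main_param)              # downcast happens on copy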