Commit fb3a1345 authored by Lawrence McAfee

partially commented distrib_optimizer.py.

parent fe2d623e
@@ -296,14 +296,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
fp16, bf16, grad_scaler, models):
"""
See top of class definition for argument descriptions.
"""
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
fp16, bf16, grad_scaler, models)
- # Verify that contiguous buffers are being used
- # - Note: this should already be checked in arguments.py
+ # Verify that contiguous buffers are being used.
+ # - Note: this should already be checked in arguments.py.
assert use_contiguous_buffers_in_local_ddp
# Model grad buffer ranges.
@@ -370,6 +373,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def load_state_dict(self, state_dict):
"""
Load the state dict.
"""
# Optimizer.
optimizer_key = 'optimizer'
@@ -400,11 +406,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
float16_groups & fp32_groups. We additionally zero
fp32_from_float16_groups as a memory optimization to reduce
"""
Zero grads.
We only need to zero the model related parameters, i.e.,
model_float16_groups & model_fp32_groups. We additionally zero
the remaining groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point."""
used by this field can be safely deallocated at this point.
"""
for groups in (
self.model_float16_groups,
self.model_fp32_groups,
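As an aside on the zeroing described above, here is a standalone sketch (hypothetical helper name zero_grad_group, not the module's own code) of why set_to_none=True saves memory rather than just writing zeros:

import torch

def zero_grad_group(params, set_to_none=True):
    # Zero (or drop) each parameter's .grad. Setting it to None releases
    # the underlying storage, which is the memory/fragmentation benefit
    # the docstring refers to; otherwise the grad is detached and zeroed.
    for param in params:
        if param.grad is None:
            continue
        if set_to_none:
            param.grad = None
        else:
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            else:
                param.grad.requires_grad_(False)
            param.grad.zero_()

# Two tiny groups standing in for model_float16_groups / model_fp32_groups.
group_a = [torch.nn.Parameter(torch.randn(4).half())]
group_b = [torch.nn.Parameter(torch.randn(4))]
for p in group_a + group_b:
    p.grad = torch.zeros_like(p)
for group in (group_a, group_b):
    zero_grad_group(group, set_to_none=True)
assert all(p.grad is None for p in group_a + group_b)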
@@ -416,6 +426,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def get_model_grad_buffer_dp_views(self):
"""
Get shard views of each of the DDP's grad buffers.
In this nested list, the top level is grouped by the virtual model
index and the grad buffer's data type. The sub-level is a list of
shards of that grad buffer, where each shard in the list represents
a contiguous view of the grad buffer that is owned by a data-parallel
rank. The shard boundary does not respect parameter boundaries, and
so the elements of some parameters are split across data parallel
ranks.
Additionally, return references to the entire grad buffers, for use
in _reduce_scatter_base and _all_gather_base.
"""
data_parallel_world_size = mpu.get_data_parallel_world_size()
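To make the sharding described in this docstring concrete, the following standalone sketch (hypothetical buffer_dp_views; it assumes, as the contiguous grad buffer does, that the buffer size is padded to a multiple of the data-parallel world size) carves a flat buffer into per-rank views that ignore parameter boundaries:

import torch

def buffer_dp_views(gbuf, data_parallel_world_size):
    # Split a flat grad buffer into equal, contiguous per-rank views.
    # Shard boundaries are based purely on element offsets, so a given
    # parameter's elements may straddle two data-parallel ranks.
    assert gbuf.numel() % data_parallel_world_size == 0
    shard_size = gbuf.numel() // data_parallel_world_size
    return [gbuf[r * shard_size:(r + 1) * shard_size]
            for r in range(data_parallel_world_size)]

# Example: a 12-element buffer holding a 5-element and a 7-element param,
# viewed by 4 data-parallel ranks (3 elements per rank); the 5-element
# param is split across ranks 0 and 1.
gbuf = torch.arange(12, dtype=torch.float32)
views = buffer_dp_views(gbuf, 4)
assert all(v.numel() == 3 for v in views)
assert views[1].data_ptr() == gbuf.data_ptr() + 3 * gbuf.element_size()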
@@ -435,6 +459,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def reduce_model_grads(self, args, timers):
"""
Reduce-scatter model grads.
The DDP's grad buffer is used for the reduce-scatter, and thus no
tensors are dynamically allocated.
Note: this is a different order of reduction than in the non-
distributed optimizer, which reduces: 1) all grads, 2) embedding
grads.
@@ -458,7 +487,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Reduce-scatter all grads.
gbuf_view_items = self.get_model_grad_buffer_dp_views()
- for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
+ for index, (model_index, dtype, gbuf, gbuf_views) \
+         in enumerate(gbuf_view_items):
torch.distributed._reduce_scatter_base(
gbuf_views[data_parallel_rank],
gbuf,
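The call above passes the full grad buffer as input and this rank's view as output to torch.distributed._reduce_scatter_base. A single-process simulation of what that collective computes (a sketch of the semantics, not the actual distributed call) is:

import torch

def simulate_reduce_scatter(rank_buffers, data_parallel_world_size):
    # Emulate a reduce-scatter over a flat grad buffer: every rank
    # contributes its full buffer, and afterwards each rank owns the
    # element-wise sum of just its own contiguous shard.
    numel = rank_buffers[0].numel()
    shard = numel // data_parallel_world_size
    total = torch.stack(rank_buffers).sum(dim=0)
    return [total[r * shard:(r + 1) * shard]
            for r in range(data_parallel_world_size)]

# Example: 2 data-parallel ranks, an 8-element grad buffer on each.
bufs = [torch.full((8,), float(r + 1)) for r in range(2)]
shards = simulate_reduce_scatter(bufs, 2)
# Each rank's shard holds 1.0 + 2.0 = 3.0 in every position.
assert all(torch.all(s == 3.0) for s in shards)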
@@ -469,6 +500,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def gather_model_params(self, args, timers):
"""
All-gather updated model params.
The DDP's grad buffer is used for the all-gather, and thus no
tensors are dynamically allocated. After the all-gather, the params
can be copied from param.main_grad to param.
"""
timers('backward-params-all-gather').start()
@@ -481,7 +519,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# in distributed.py. Thus, all sub-views will have consistent start/end
# indexes across data parallel ranks.
gbuf_view_items = self.get_model_grad_buffer_dp_views()
- for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
+ for index, (model_index, dtype, gbuf, gbuf_views) \
+         in enumerate(gbuf_view_items):
torch.distributed._all_gather_base(
gbuf,
gbuf_views[data_parallel_rank],
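Here _all_gather_base is the mirror image of the reduce-scatter: the rank's view is the input and the full buffer is the output, so afterwards every rank holds every rank's updated shard. A single-process simulation of that result (semantics only, not the distributed call):

import torch

def simulate_all_gather(shards):
    # Emulate an all-gather over the grad buffer: each rank contributes
    # the shard it owns, and every rank ends up with the full,
    # concatenated buffer (here holding updated parameter values).
    full = torch.cat(shards)
    return [full.clone() for _ in shards]

# Example: 2 ranks, each owning a 4-element updated shard.
shards = [torch.full((4,), float(r)) for r in range(2)]
buffers = simulate_all_gather(shards)
assert all(torch.equal(b, torch.tensor([0., 0., 0., 0., 1., 1., 1., 1.]))
           for b in buffers)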
@@ -499,6 +539,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def _collect_main_grad_data_for_unscaling(self):
"""
Note: this should be equivalent to the float-16 optimizer's method,
but written differently, so the two should be combined.
"""
return [
param.grad.data
for group in self.optimizer.param_groups
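The comprehension above simply flattens every main grad into one list so the grad scaler can unscale them in a single pass. A minimal sketch of that flow (hypothetical collect_and_unscale, with a plain loop standing in for the scaler's fused unscale-and-check):

import torch

def collect_and_unscale(param_groups, inv_scale):
    # Gather the raw grad tensors from every param group and multiply
    # each one by the inverse loss scale.
    main_grads = [param.grad.data
                  for group in param_groups
                  for param in group['params']
                  if param.grad is not None]
    for grad in main_grads:
        grad.mul_(inv_scale)
    return main_grads

# Throwaway example: grads produced under a pretend loss scale of 1024.
params = [torch.nn.Parameter(torch.ones(3)) for _ in range(2)]
for p in params:
    p.grad = torch.full_like(p, 1024.0)
opt = torch.optim.SGD(params, lr=0.1)
grads = collect_and_unscale(opt.param_groups, 1.0 / 1024.0)
assert all(torch.allclose(g, torch.ones(3)) for g in grads)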
@@ -507,6 +551,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def _get_model_and_main_params_data_float16(self):
"""
Get aligned list of model and main params.
"""
model_data = []
main_data = []
for model_group, main_group in zip(self.shard_float16_groups,
@@ -518,7 +565,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def _copy_model_grads_to_main_grads(self):
"""
Copy model grads to main grads.
Since this step follows a reduce-scatter through the DDP's grad
buffer, this method is responsible for copying the updated grads
from the grad buffer to the main shard's grad field.
"""
# Utility method for copying group grads.
def copy_group_grads(model_groups, shard_main_groups):
for model_group, shard_main_group in zip(model_groups,
shard_main_groups):
@@ -534,6 +589,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
[param_range.start:param_range.end]
shard_main_param.grad = shard_model_grad.float()
# Copy model groups to shard groups.
copy_group_grads(self.model_float16_groups,
self.shard_fp32_from_float16_groups)
copy_group_grads(self.model_fp32_groups,
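A standalone sketch of the per-parameter copy this method performs (hypothetical copy_model_grad_to_main_shard; the element range 4..12 is just an illustrative stand-in for param_range.start/end):

import torch

def copy_model_grad_to_main_shard(model_param, shard_main_param, start, end):
    # After the reduce-scatter, model_param.main_grad holds reduced grads.
    # This rank only optimizes elements [start, end) of the flattened
    # param, so only that slice is copied (as fp32) into the main shard.
    shard_model_grad = model_param.main_grad.view(-1)[start:end]
    shard_main_param.grad = shard_model_grad.float()

# A 4x4 fp16 param whose flattened elements 4..12 belong to this rank.
model_param = torch.nn.Parameter(torch.zeros(4, 4, dtype=torch.float16))
model_param.main_grad = torch.randn(4, 4).half()
shard_main_param = torch.nn.Parameter(torch.zeros(8, dtype=torch.float32))
copy_model_grad_to_main_shard(model_param, shard_main_param, 4, 12)
assert shard_main_param.grad.dtype == torch.float32
assert torch.allclose(shard_main_param.grad,
                      model_param.main_grad.view(-1)[4:12].float())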
@@ -541,7 +597,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def _copy_main_params_to_model_params(self):
"""
Copy main params to model params.
Since this step is followed by an all-gather through the DDP's grad
buffer, this method is responsible for copying the updated params
from the main shards into the correct position in the grad buffer.
"""
# Utility method for copying group params.
def copy_group_params(shard_main_groups, model_groups):
for shard_main_group, model_group in zip(shard_main_groups,
model_groups):
@@ -558,6 +622,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
shard_model_grad.data.copy_(shard_main_param)
# Copy shard groups to model groups.
copy_group_params(self.shard_fp32_from_float16_groups,
self.model_float16_groups)
copy_group_params(self.shard_fp32_groups,
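And the reverse direction, writing the updated fp32 shard back into the grad buffer ahead of the all-gather (hypothetical copy_main_shard_to_model_param, same illustrative element range as above):

import torch

def copy_main_shard_to_model_param(shard_main_param, model_param, start, end):
    # The updated fp32 main shard is written into this rank's slice of
    # the param's grad-buffer view (main_grad); the subsequent all-gather
    # then propagates it to every data-parallel rank.
    shard_model_grad = model_param.main_grad.view(-1)[start:end]
    shard_model_grad.data.copy_(shard_main_param)

# Write an 8-element updated shard into elements 4..12 of the fp16
# buffer backing a 4x4 param.
model_param = torch.nn.Parameter(torch.zeros(4, 4, dtype=torch.float16))
model_param.main_grad = torch.zeros(4, 4, dtype=torch.float16)
shard_main_param = torch.arange(8, dtype=torch.float32)
copy_main_shard_to_model_param(shard_main_param, model_param, 4, 12)
assert torch.equal(model_param.main_grad.view(-1)[4:12].float(),
                   torch.arange(8, dtype=torch.float32))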