"docs/source/vscode:/vscode.git/clone" did not exist on "5c2a448a59b51095b0eb4c4fa75efbc8d933b9a2"
Commit b42c3052 authored by Lawrence McAfee

commented distrib_optimizer.py.

parent fb3a1345
...@@ -75,6 +75,30 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
@classmethod
def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
"""
Build mapping from param reference to grad buffer shard ranges.
This method builds a mapping from parameter references to grad
buffer shard ranges, specific to each data-parallel (DP) rank's
set of 'owned' parameters. Each grad buffer (padded to be an even
multiple of DP-world-size) is conceptually divided into DP-world-size
contiguous regions, where each DP rank 'owns' a contiguous regions.
Ownership in this sense means DP rank is responsible for reducing
the relevant subset of grads, and updating the relevant subset of
params.
This conceptual partitioning of the grad buffer does NOT respect
parameter boundaries, and as such it is assumed that each created
range references a shard (or subset) of the full parameter. It is
easiest to think of each DP rank as operating (i.e., reducing,
gathering) purely on views into the grad buffer, for all model-to-
main & main-to-model operations.
This method creates three ranges:
- The param's range within the entire grad buffer (i.e., world index).
- The param's range within the DP rank's local view of the grad buffer.
- The param's range within itself (i.e., its shard).
"""
# Param range map.
param_world_index_map = model._grad_buffer_param_index_map[dtype]
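The shard arithmetic described in the docstring above can be illustrated with a small, self-contained sketch. This is not the actual implementation: the function name, tuple-based ranges, and example numbers below are invented purely for illustration.

def shard_ranges(param_world_start, param_world_end,
                 gbuf_size, dp_rank, dp_world_size):
    # Each DP rank owns one contiguous region of the (padded) grad buffer.
    shard_size = gbuf_size // dp_world_size
    own_start = dp_rank * shard_size
    own_end = own_start + shard_size

    # Intersect the param's world range with this rank's owned region.
    world_start = max(param_world_start, own_start)
    world_end = min(param_world_end, own_end)
    if world_start >= world_end:
        return None  # this rank owns no part of this param

    return {
        # 1) Range within the entire grad buffer.
        "gbuf_world": (world_start, world_end),
        # 2) Range within this rank's local view of the grad buffer.
        "gbuf_local": (world_start - own_start, world_end - own_start),
        # 3) Range within the param itself (its shard).
        "param": (world_start - param_world_start,
                  world_end - param_world_start),
    }

# A 10-element param at offset 5 in a 32-element buffer, with 4 DP ranks:
# rank 1 owns buffer region [8, 16), so it holds param elements [3, 10).
print(shard_ranges(5, 15, 32, 1, 4))
# {'gbuf_world': (8, 15), 'gbuf_local': (0, 7), 'param': (3, 10)}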
...@@ -108,6 +132,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
@classmethod
def build_model_gbuf_range(cls, model, dtype):
"""
Build mapping between params and their grad buffers.
This method does the initial setup for the method above. This setup
includes determining the shard ranges into the DDP's grad buffer for
each data-parallel (DP) rank. Each DP rank keeps range info for
all other DP ranks, for the purpose of creating args for
reduce-scatter and all-gather.
"""
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
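For intuition, the per-rank slicing this setup performs can be sketched as follows; the function name and plain tuples are assumptions used only for illustration.

def all_rank_world_ranges(gbuf_size, dp_world_size):
    # The grad buffer is padded to an even multiple of the DP world size,
    # so every rank's slice has the same length.
    assert gbuf_size % dp_world_size == 0
    shard_size = gbuf_size // dp_world_size
    # Every rank keeps the slices of *all* ranks, so it can later assemble
    # the per-rank argument lists for reduce-scatter and all-gather.
    return [(r * shard_size, (r + 1) * shard_size)
            for r in range(dp_world_size)]

print(all_rank_world_ranges(32, 4))
# [(0, 8), (8, 16), (16, 24), (24, 32)]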
...@@ -134,7 +167,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
dtype,
gbuf_world_range)
# Group into dict.
data = {
"local" : gbuf_local_range,
"world" : gbuf_world_range,
...@@ -148,6 +181,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
@classmethod
def build_model_gbuf_range_map(cls, model):
"""
Create param-to-grad-buffer mappings, for grad buffer data types
within a specific virtual model.
"""
return {
dtype : cls.build_model_gbuf_range(model, dtype)
for dtype in model._grad_buffers
...@@ -170,6 +207,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
@classmethod
def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
"""
Create optimizer groups.
Given the set of parameter shard ranges that are owned by the current
data-parallel (DP) rank, gather the set of parameters that will be
used (in the method below) to create the current DP's optimizer
groups.
"""
num_groups = len(param_groups)
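A rough sketch of the grouping idea follows. The helper name and the use of raw params (rather than the range structures the real method builds) are simplifications for illustration.

def build_group_ranges_sketch(param_groups, owned_params):
    # Map each param back to the index of its original optimizer group.
    param_to_group = {}
    for group_index, group in enumerate(param_groups):
        for param in group["params"]:
            param_to_group[param] = group_index

    # Mirror the original group layout, but keep only the params for which
    # this DP rank owns a shard.
    group_ranges = [{"params": []} for _ in param_groups]
    for param in owned_params:
        group_ranges[param_to_group[param]]["params"].append(param)
    return group_ranges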
...@@ -199,9 +244,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
@classmethod
def build_model_and_main_param_groups(cls,
model_gbuf_ranges,
param_gbuf_map,
opt_group_ranges):
"""
Create main parameter groups needed for the optimizer step.
These groups encompass both: 1) groups used by this class, for
reducing/gather, and 2) groups used by the inner optimizer for the
parameter update. Given that the conceptual grad buffer partitioning
(created in earlier method) doesn't respect parameter boundaries,
the optimizer operates on shards of the model parameters, rather than
the full parameters.
"""
# Parameter groups:
# model_float16_groups: original float16 parameters
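The shard-based main parameters can be pictured with this simplified sketch; the helper name is invented, and the real method also handles fp32 model params and wires up gradient views, which is omitted here.

import torch

def make_main_shard(model_param, param_range):
    start, end = param_range
    # View of the slice of the float16/bf16 model param owned by this rank...
    model_shard_view = model_param.detach().view(-1)[start:end]
    # ...and a detached fp32 copy of that slice for the optimizer to update.
    main_shard = model_shard_view.clone().float()
    main_shard.requires_grad = True
    return main_shard

param = torch.randn(10).half()           # model param (float16)
main = make_main_shard(param, (3, 10))   # this rank owns elements [3, 10)
print(main.dtype, main.shape)            # torch.float32 torch.Size([7])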
...@@ -298,6 +353,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
fp16, bf16, grad_scaler, models):
"""
See top of class definition for argument descriptions.
The steps in this method create the core mapping between DDP grad
buffers, parameters, and parameter shard ranges, that is needed for
converting between model param indexes and main parameter shard
indexes. This method also updates the optimizer parameter groups
with the newly created shards.
""" """
super().__init__(
...
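As a closing illustration of the mapping this constructor sets up, the shard_ranges() sketch from earlier can be applied across a toy param index map for a single DP rank. The data and names are illustrative only, not Megatron's actual structures.

# Toy world-index map: param name -> (start, end) offsets in the grad buffer.
param_world_index_map = {"w1": (0, 12), "b1": (12, 15), "w2": (15, 30)}
gbuf_size, dp_rank, dp_world_size = 32, 0, 4   # rank 0 owns region [0, 8)

param_range_map = {}
for name, (start, end) in param_world_index_map.items():
    ranges = shard_ranges(start, end, gbuf_size, dp_rank, dp_world_size)
    if ranges is not None:          # skip params this rank owns no part of
        param_range_map[name] = ranges

print(list(param_range_map))
# ['w1']  -- rank 0's region [0, 8) only overlaps w1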