Commit fe2d623e authored by Lawrence McAfee

replace triple single w/ triple double quote.

parent 7ed649ed
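The change itself is purely stylistic: PEP 257 recommends triple double quotes for docstrings, so the patch swaps the ''' delimiters for """ without touching any logic. A generic before/after sketch of the convention (illustrative only, not code from this commit):

    # Before: triple single quotes -- legal Python, but not the PEP 257 convention.
    def shard_size(start, end):
        '''Return the number of elements in a [start, end) shard.'''
        return end - start

    # After: triple double quotes, as PEP 257 recommends for docstrings.
    def shard_size(start, end):
        """Return the number of elements in a [start, end) shard."""
        return end - start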
@@ -29,9 +29,10 @@ from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
 class Range:
-    '''A range represents a start and end points for indexing a shard
+    """
+    A range represents a start and end points for indexing a shard
     from a full tensor.
-    '''
+    """
     def __init__(self, start, end):
         self.start = start
         self.end = end
@@ -43,7 +44,7 @@ class Range:
 class DistributedOptimizer(MixedPrecisionOptimizer):
-    '''Distributed optimizer, for all data types (fp16, bf16, and fp32).
+    """Distributed optimizer, for all data types (fp16, bf16, and fp32).
     Arguments:
         optimizer: base optimizer such as Adam or SGD
@@ -70,7 +71,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             always require a grad scaler.
         models: list of models (i.e., the virtual pipelining models). This
             is used by the distributed optimizer for mapping parameters.
-    '''
+    """

     @classmethod
     def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
@@ -155,8 +156,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     @classmethod
     def build_model_param_gbuf_map(cls, model_gbuf_ranges):
-        '''Create a reverse of the model_gbuf_ranges, for referencing in
-        opposite direction.'''
+        """
+        Create a reverse of the model_gbuf_ranges, for referencing in
+        opposite direction.
+        """
         param_gbuf_map = {}
         for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
             for dtype, gbuf_range_map in model_gbuf_range_map.items():
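For context on the docstring just added: build_model_param_gbuf_map inverts a nested mapping of model index -> dtype -> per-parameter ranges into a flat parameter -> (model index, dtype) lookup. A self-contained sketch of that reverse-mapping pattern, using toy placeholder data rather than Megatron's real parameter objects and Range values:

    import torch

    # Toy stand-in for model_gbuf_ranges:
    # model index -> dtype -> {"param_map": {param: index range}}.
    model_gbuf_ranges = [
        {torch.float16: {"param_map": {"layer0.weight": (0, 128), "layer0.bias": (128, 160)}}},
        {torch.float32: {"param_map": {"layer1.weight": (0, 64)}}},
    ]

    # Reverse it into param -> (model_index, dtype), as the docstring describes.
    param_gbuf_map = {}
    for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
        for dtype, gbuf_range_map in model_gbuf_range_map.items():
            for param in gbuf_range_map["param_map"]:
                param_gbuf_map[param] = (model_index, dtype)

    assert param_gbuf_map["layer1.weight"] == (1, torch.float32)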
@@ -335,10 +338,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_model_param_range_map(self, param):
-        '''
+        """
         Given a model param, get the index sub-range of the param that this
         data-parallel rank owns.
-        '''
+        """
         model_index, dtype = self.model_param_gbuf_map[param]
         gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
         param_range_map = gbuf_range_map["param_map"][param]
@@ -346,10 +349,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_model_parallel_group(self):
+        """
+        With the distributed optimizer, the model parallel group is the
+        entire world.
+        """
         return None

     def state_dict(self):
+        """
+        The state dict must contain the fp32-from-float16 shards.
+        """
         state_dict = {}
         state_dict['optimizer'] = self.optimizer.state_dict()
         if self.grad_scaler:
@@ -424,10 +434,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def reduce_model_grads(self, args, timers):
-        '''Note: this is a different order of reduction, versus the non-
-        distributed optimizer, which reduces: 1) all grads, 2) embedding
-        grads.
-        '''
+        """
+        Note: this is a different order of reduction, versus the non-
+        distributed optimizer, which reduces: 1) all grads, 2) embedding
+        grads.
+        """

         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
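The docstring added here only records an ordering difference between the two optimizers; schematically (hypothetical helper names, not the real methods):

    # Non-distributed optimizer: 1) reduce all grads, 2) reduce embedding grads.
    # Distributed optimizer (this class): embedding grads first, then the grad buffers.
    def reduce_model_grads_sketch(reduce_grad_buffers, allreduce_embedding_grads):
        allreduce_embedding_grads()  # embedding grads first, as the comment in the hunk shows
        reduce_grad_buffers()        # then the remaining gradients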
@@ -122,7 +122,7 @@ class MegatronOptimizer(ABC):
     def get_model_parallel_group(self):
-        '''Default returned here, but the distributed optimizer overrides this.'''
+        """Default returned here, but the distributed optimizer overrides this."""
         return mpu.get_model_parallel_group()
@@ -205,19 +205,21 @@ class MegatronOptimizer(ABC):
     def gather_model_params(self, args, timers):
-        '''For the case of a non-distributed-optimizer, there is nothing to
-        do here.'''
+        """
+        For the case of a non-distributed-optimizer, there is nothing to
+        do here.
+        """
         pass

     def allreduce_word_embedding_grads(self, args):
-        '''
+        """
         All-reduce word embedding grads.
         Reduce grads across first and last stages to ensure that word_embeddings
         parameters stay in sync. This should only run for models that support
         pipelined model parallelism (BERT and GPT-2).
-        '''
+        """
         if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                 mpu.get_pipeline_model_parallel_world_size() > 1:
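The word-embedding hunk above relies on an all-reduce between the first and last pipeline stages, which both hold a copy of the embedding weights. A minimal sketch of that sync pattern in plain torch.distributed, assuming an already-initialized process group; embedding_group here is a hypothetical handle, not Megatron's mpu API:

    import torch
    import torch.distributed as dist

    def sync_shared_embedding_grad(word_embeddings: torch.nn.Embedding, embedding_group) -> None:
        # Summing the grads of the shared weight keeps both copies stepping identically.
        grad = word_embeddings.weight.grad
        if grad is not None:
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=embedding_group)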
@@ -240,12 +242,12 @@ class MegatronOptimizer(ABC):
     def allreduce_position_embedding_grads(self, args):
-        '''
+        """
         All-reduce position_embeddings grad across first (encoder) and
         split (decoder) stages to ensure that position embeddings parameters
         stay in sync. This should only run for T5 models with pipeline
         parallelism.
-        '''
+        """
         if mpu.is_rank_in_position_embedding_group() and \
                 mpu.get_pipeline_model_parallel_world_size() > 1 and \
                 args.pipeline_model_parallel_split_rank is not None:
@@ -259,13 +261,13 @@ class MegatronOptimizer(ABC):
     def allreduce_embedding_grads(self, args):
-        '''All-reduce both word and position embeddings.'''
+        """All-reduce both word and position embeddings."""
         self.allreduce_word_embedding_grads(args)
         self.allreduce_position_embedding_grads(args)

     def reduce_model_grads(self, args, timers):
-        '''All-reduce all grads, and all-reduce embeddings.'''
+        """All-reduce all grads, and all-reduce embeddings."""

         # All-reduce if needed.
         if args.DDP_impl == 'local':