Commit ee57e086 authored by Lawrence McAfee's avatar Lawrence McAfee

moved SP's layernorm all-reduce into optimizer.py.

parent c9a59554
@@ -527,10 +527,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         tensors are dynamically allocated.

         Note: this is a different order of reduction, versus the non-
-        distributed optimizer, which reduces: 1) all grads, 2) embedding
-        grads.
+        distributed optimizer, which reduces: 1) layernorm grads, 2) all
+        grads, 3) embedding grads.
         """

+        # All-reduce layer-norm grads (for sequence parallelism).
+        timers('backward-layernorm-all-reduce').start()
+        self.allreduce_layernorm_grads(args)
+        timers('backward-layernorm-all-reduce').stop()
+
         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
         self.allreduce_embedding_grads(args)
...
@@ -266,9 +266,38 @@ class MegatronOptimizer(ABC):
         self.allreduce_position_embedding_grads(args)

+    def allreduce_layernorm_grads(self, args):
+        """All-reduce layernorm grads (for sequence parallelism)."""
+
+        # All-reduce layernorm parameters across model parallel nodes
+        # when sequence parallelism is used.
+        if mpu.get_tensor_model_parallel_world_size() > 1 and \
+                args.sequence_parallel:
+            grads = []
+            for model_module in self.models:
+                unwrapped_model = unwrap_model(
+                    model_module, (torchDDP, LocalDDP, Float16Module))
+                for param in unwrapped_model.parameters():
+                    if getattr(param, 'sequence_parallel', False):
+                        grad = param.main_grad if args.DDP_impl == 'local' else param.grad
+                        grads.append(grad.data)
+            coalesced = _flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(
+                coalesced, group=mpu.get_tensor_model_parallel_group())
+            for buf, synced in zip(grads, _unflatten_dense_tensors(
+                    coalesced, grads)):
+                buf.copy_(synced)
+
     def reduce_model_grads(self, args, timers):
         """All-reduce all grads, and all-reduce embeddings."""

+        # All-reduce layer-norm grads (for sequence parallelism).
+        timers('backward-layernorm-all-reduce').start()
+        self.allreduce_layernorm_grads(args)
+        timers('backward-layernorm-all-reduce').stop()
+
         # All-reduce if needed.
         if args.DDP_impl == 'local':
             timers('backward-params-all-reduce').start()
...
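
For context, below is a minimal, standalone sketch of the coalesced all-reduce pattern that allreduce_layernorm_grads applies to the sequence-parallel layernorm grads: flatten the gradient tensors into one buffer, all-reduce that buffer once over the tensor-model-parallel group, then unflatten and copy the reduced values back in place. The helper name allreduce_coalesced and the toy gradients are illustrative only, not part of Megatron-LM; the collective call is skipped when no process group is initialized.

    # Illustrative sketch only -- not Megatron-LM code. Assumes plain PyTorch.
    import torch
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


    def allreduce_coalesced(grads, group=None):
        """Flatten grads into one buffer, all-reduce it once, copy results back."""
        if not grads:
            return
        coalesced = _flatten_dense_tensors(grads)       # one contiguous buffer
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(coalesced, group=group)     # single collective call
        for buf, synced in zip(grads,
                               _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)                           # write back in place


    if __name__ == '__main__':
        # Toy tensors standing in for the sequence-parallel layernorm grads.
        toy_grads = [torch.ones(4), torch.full((2, 3), 2.0)]
        allreduce_coalesced(toy_grads)  # no-op reduction without an initialized group
        print([tuple(g.shape) for g in toy_grads])

Coalescing into a single buffer keeps the number of collective calls independent of how many layernorm parameters the model has, which is why the optimizer gathers all sequence-parallel grads before issuing one all_reduce.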