Commit 08ee8ea2 authored by Lawrence McAfee

updated args for allreduce_embeddings

parent c88bc979
@@ -334,7 +334,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
 
         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
-        self.allreduce_embedding_grads()
+        self.allreduce_embedding_grads(args)
         timers('backward-embedding-all-reduce').stop()
 
         # Reduce-scatter all grads.
@@ -17,15 +17,17 @@
 from abc import ABC
 from abc import abstractmethod
 
 import torch
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
+import torch
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.utils import unwrap_model
 
 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
@@ -190,7 +192,7 @@ class MegatronOptimizer(ABC):
         do here.'''
         pass
 
-    def allreduce_word_embedding_grads(self):
+    def allreduce_word_embedding_grads(self, args):
         '''
         All-reduce word embedding grads.
@@ -202,11 +204,11 @@ class MegatronOptimizer(ABC):
         if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
             if mpu.is_pipeline_first_stage(ignore_virtual=True):
-                unwrapped_model = model[0]
+                unwrapped_model = self.models[0]
             elif mpu.is_pipeline_last_stage(ignore_virtual=True):
-                unwrapped_model = model[-1]
+                unwrapped_model = self.models[-1]
             else:  # We do not support the interleaved schedule for T5 yet.
-                unwrapped_model = model[0]
+                unwrapped_model = self.models[0]
             unwrapped_model = unwrap_model(
                 unwrapped_model, (torchDDP, LocalDDP, Float16Module))
@@ -218,7 +220,7 @@
                 grad = word_embeddings_weight.grad
             torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
 
-    def allreduce_position_embedding_grads(self):
+    def allreduce_position_embedding_grads(self, args):
         '''
         All-reduce position_embeddings grad across first (encoder) and
         split (decoder) stages to ensure that position embeddings parameters
@@ -228,7 +230,7 @@
         if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
-            unwrapped_model = model[0]
+            unwrapped_model = self.models[0]
             unwrapped_model = unwrap_model(
                 unwrapped_model, (torchDDP, LocalDDP, Float16Module))
             assert args.DDP_impl == 'local', \
@@ -236,9 +238,9 @@
             grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
             torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
 
-    def allreduce_embedding_grads(self):
-        self.allreduce_word_embedding_grads()
-        self.allreduce_position_embedding_grads()
+    def allreduce_embedding_grads(self, args):
+        self.allreduce_word_embedding_grads(args)
+        self.allreduce_position_embedding_grads(args)
 
 
     def reduce_model_grads(self, args, timers):
@@ -251,7 +253,7 @@
 
         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
-        self.allreduce_embedding_grads()
+        self.allreduce_embedding_grads(args)
         timers('backward-embedding-all-reduce').stop()
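Note on the change: in the previous version these optimizer methods referenced a free variable `model` and used `args` without ever receiving it; this commit threads `args` from `reduce_model_grads(args, timers)` through `allreduce_embedding_grads(args)` down to the word- and position-embedding helpers, and reads model chunks from `self.models`. Below is a minimal, self-contained sketch of the resulting call pattern. It is not the actual Megatron-LM code: the `OptimizerSketch` and `_Timer` classes and the SimpleNamespace `args` are placeholders introduced only to illustrate the parameter threading.

# Hedged sketch of the call chain after this commit; all names below other
# than the method names shown in the diff are hypothetical stand-ins.

from types import SimpleNamespace


class _Timer:
    """Placeholder for Megatron's timer objects."""
    def start(self):
        pass

    def stop(self):
        pass


class OptimizerSketch:
    """Illustrates how `args` is forwarded and how model chunks are read."""

    def __init__(self, models):
        # Model chunks live on the optimizer, so the helpers below use
        # self.models instead of a free variable named `model`.
        self.models = models

    def allreduce_word_embedding_grads(self, args):
        # The real method picks self.models[0] / self.models[-1] by pipeline
        # stage and all-reduces the shared word-embedding gradient.
        print('word-embedding grad all-reduce, DDP impl:', args.DDP_impl)

    def allreduce_position_embedding_grads(self, args):
        # The real method only runs for split encoder/decoder pipelines (T5).
        if args.pipeline_model_parallel_split_rank is not None:
            print('position-embedding grad all-reduce')

    def allreduce_embedding_grads(self, args):
        # `args` is now forwarded to both helpers.
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def reduce_model_grads(self, args, timers):
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()


if __name__ == '__main__':
    args = SimpleNamespace(DDP_impl='local',
                           pipeline_model_parallel_split_rank=None)
    OptimizerSketch(models=['chunk0']).reduce_model_grads(
        args, timers=lambda name: _Timer())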