Commit 08ee8ea2 authored by Lawrence McAfee

updated args for allreduce_embeddings

parent c88bc979
@@ -334,7 +334,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
 
         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
-        self.allreduce_embedding_grads()
+        self.allreduce_embedding_grads(args)
         timers('backward-embedding-all-reduce').stop()
 
         # Reduce-scatter all grads.
...
@@ -17,15 +17,17 @@
 from abc import ABC
 from abc import abstractmethod
 
-import torch
-
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
+import torch
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module
+from megatron.utils import unwrap_model
 
 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
@@ -190,7 +192,7 @@ class MegatronOptimizer(ABC):
         do here.'''
         pass
 
-    def allreduce_word_embedding_grads(self):
+    def allreduce_word_embedding_grads(self, args):
         '''
         All-reduce word embedding grads.
@@ -202,11 +204,11 @@ class MegatronOptimizer(ABC):
         if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                 mpu.get_pipeline_model_parallel_world_size() > 1:
             if mpu.is_pipeline_first_stage(ignore_virtual=True):
-                unwrapped_model = model[0]
+                unwrapped_model = self.models[0]
             elif mpu.is_pipeline_last_stage(ignore_virtual=True):
-                unwrapped_model = model[-1]
+                unwrapped_model = self.models[-1]
             else:  # We do not support the interleaved schedule for T5 yet.
-                unwrapped_model = model[0]
+                unwrapped_model = self.models[0]
             unwrapped_model = unwrap_model(
                 unwrapped_model, (torchDDP, LocalDDP, Float16Module))
@@ -218,7 +220,7 @@ class MegatronOptimizer(ABC):
             grad = word_embeddings_weight.grad
             torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
 
-    def allreduce_position_embedding_grads(self):
+    def allreduce_position_embedding_grads(self, args):
         '''
         All-reduce position_embeddings grad across first (encoder) and
         split (decoder) stages to ensure that position embeddings parameters
@@ -228,7 +230,7 @@ class MegatronOptimizer(ABC):
         if mpu.is_rank_in_position_embedding_group() and \
                 mpu.get_pipeline_model_parallel_world_size() > 1 and \
                 args.pipeline_model_parallel_split_rank is not None:
-            unwrapped_model = model[0]
+            unwrapped_model = self.models[0]
             unwrapped_model = unwrap_model(
                 unwrapped_model, (torchDDP, LocalDDP, Float16Module))
             assert args.DDP_impl == 'local', \
@@ -236,9 +238,9 @@ class MegatronOptimizer(ABC):
             grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
             torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
 
-    def allreduce_embedding_grads(self):
-        self.allreduce_word_embedding_grads()
-        self.allreduce_position_embedding_grads()
+    def allreduce_embedding_grads(self, args):
+        self.allreduce_word_embedding_grads(args)
+        self.allreduce_position_embedding_grads(args)
 
     def reduce_model_grads(self, args, timers):
@@ -251,7 +253,7 @@ class MegatronOptimizer(ABC):
 
         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
-        self.allreduce_embedding_grads()
+        self.allreduce_embedding_grads(args)
         timers('backward-embedding-all-reduce').stop()
...
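
For illustration only, a self-contained toy sketch of the calling pattern this commit establishes: `args` is threaded from `reduce_model_grads(args, timers)` through `allreduce_embedding_grads(args)` into the word/position embedding helpers, which read fields such as `args.DDP_impl` and `args.pipeline_model_parallel_split_rank`, and the optimizer references its model chunks via `self.models` instead of a local `model` variable. `ToyOptimizer` and `toy_args` are invented stand-ins, not Megatron code; the real methods issue `torch.distributed.all_reduce` calls over the embedding groups shown in the diff above.

# Minimal standalone sketch (assumed names, not the Megatron implementation).
from types import SimpleNamespace


class ToyOptimizer:
    def __init__(self, models):
        # The optimizer keeps the list of model chunks it was built for,
        # mirroring `self.models` in the diff.
        self.models = models

    def allreduce_word_embedding_grads(self, args):
        # In Megatron this selects the first or last pipeline stage's chunk
        # and all-reduces its shared word-embedding grad; the sketch only
        # shows the chunk selection plus the `args` usage.
        unwrapped_model = self.models[0]
        print('word-embedding all-reduce for', unwrapped_model,
              'with DDP_impl =', args.DDP_impl)

    def allreduce_position_embedding_grads(self, args):
        # Only relevant when a pipeline split rank is configured (T5-style
        # encoder/decoder pipelines), hence the need to receive `args` here.
        if args.pipeline_model_parallel_split_rank is not None:
            print('position-embedding all-reduce')

    def allreduce_embedding_grads(self, args):
        # The commit adds `args` so it can be forwarded to both helpers.
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def reduce_model_grads(self, args, timers=None):
        # Call site updated by the commit: pass `args` through.
        self.allreduce_embedding_grads(args)


toy_args = SimpleNamespace(DDP_impl='local',
                           pipeline_model_parallel_split_rank=None)
ToyOptimizer(models=['chunk0']).reduce_model_grads(toy_args)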