Commit d58d1762 authored by Lawrence McAfee

passing args, timers from train_step to optimizer methods.

parent 45b364b1
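
For orientation: this commit threads the args and timers objects that train_step already holds into the optimizer entry points (reduce_grads, step, gather_params), so the optimizer methods no longer re-fetch them via get_timers(). Below is a minimal, self-contained sketch of the resulting call pattern; the method bodies and the train_step skeleton are placeholders for illustration, not the actual Megatron implementation, and for simplicity the sketch gives step the same ITERATION debug argument that the concrete optimizers and train_step use (the abstract declaration in the diff omits it).

from abc import ABC, abstractmethod

class MegatronOptimizer(ABC):
    # After this commit, args and timers are supplied by the caller
    # (train_step) instead of being re-created inside each method.

    @abstractmethod
    def reduce_grads(self, args, timers):
        """All-reduce (or reduce-scatter) gradients across data-parallel ranks."""

    @abstractmethod
    def step(self, args, timers, ITERATION):
        """Unscale/clip gradients and step the underlying optimizer."""

    @abstractmethod
    def gather_params(self, args, timers, ITERATION):
        """All-gather updated parameters (distributed optimizer only)."""

def train_step(args, timers, optimizer, ITERATION):
    # Reduce gradients.
    optimizer.reduce_grads(args, timers)

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = \
        optimizer.step(args, timers, ITERATION)
    timers('optimizer').stop()

    # Gather params.
    optimizer.gather_params(args, timers, ITERATION)

    return update_successful, grad_norm, num_zeros_in_grad

Passing timers explicitly keeps the timing scopes owned by train_step and avoids a hidden dependency on the global get_timers() state inside the optimizer.
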
@@ -402,128 +402,28 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.allreduce_embedding_grads()
timers('backward-embedding-all-reduce').stop()
# # All-reduce word_embeddings' grad across first and last stages to ensure
# # that word_embeddings parameters stay in sync.
# # This should only run for models that support pipelined model parallelism
# # (BERT and GPT-2).
# timers('backward-embedding-all-reduce').start()
# if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
# mpu.get_pipeline_model_parallel_world_size() > 1:
# if mpu.is_pipeline_first_stage(ignore_virtual=True):
# unwrapped_model = model[0]
# elif mpu.is_pipeline_last_stage(ignore_virtual=True):
# unwrapped_model = model[-1]
# else: # We do not support the interleaved schedule for T5 yet.
# unwrapped_model = model[0]
# unwrapped_model = unwrap_model(
# unwrapped_model, (torchDDP, LocalDDP, Float16Module))
# if unwrapped_model.share_word_embeddings:
# word_embeddings_weight = unwrapped_model.word_embeddings_weight()
# if args.DDP_impl == 'local':
# grad = word_embeddings_weight.main_grad
# else:
# raise Exception("only 'main_grad' supported for distrib-opt.")
# grad = word_embeddings_weight.grad
# torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# # All-reduce position_embeddings grad across first (encoder) and split (decoder)
# # stages to ensure that position embeddings parameters stay in sync.
# # This should only run for T5 models with pipeline parallelism
# if mpu.is_rank_in_position_embedding_group() and \
# mpu.get_pipeline_model_parallel_world_size() > 1 and \
# args.pipeline_model_parallel_split_rank is not None:
# # >>>
# raise Exception("[fix] ready for t5 sync?")
# # <<<
# unwrapped_model = model[0]
# unwrapped_model = unwrap_model(
# unwrapped_model, (torchDDP, LocalDDP, Float16Module))
# assert args.DDP_impl == 'local', \
# 'T5 model is only supported with local DDP mode'
# # >>>
# grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
# torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
# # +++
# # grad_shard = optimizer.get_grad_shard(
# # unwrapped_model.language_model.embedding.position_embeddings.weight)
# # torch.distributed.all_reduce(grad_shard,
# # group=mpu.get_position_embedding_group())
# # <<<
# timers('backward-embedding-all-reduce').stop()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# timers('backward-params-reduce-scatter').start()
# Reduce-scatter all grads.
timers('backward-params-all-reduce').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
gbuf_view_items = self.get_model_grad_buffer_dp_views()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
# pax(0, {"gbufs": [
# g.data
# for m in self.models
# for g in m._grad_buffers.values()
# ]})
# >>>
# buffer_.data /= mpu.get_data_parallel_world_size()
# torch.distributed.all_reduce(
# buffer_.data, group=mpu.get_data_parallel_group())
# <<<
# >>>
# self.debug_main_param(0, "before reduce scatter")
# self.debug_main_grad(0, "before reduce scatter")
# <<<
for model_index, dtype, gbuf_views in gbuf_view_items:
# coalesced /= mpu.get_data_parallel_world_size()
gbuf = self.models[model_index]._grad_buffers[dtype].data
# >>>
# ~~ distributed.py ~~
# gbuf /= data_parallel_world_size
# torch.distributed.all_reduce(gbuf, group=data_parallel_group)
# pax(0, {
# "gbuf" : tp(gbuf),
# })
# <<<
# torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
# gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
gbuf /= data_parallel_world_size
# if 1:
# try:
# pax(0, {"gbuf_views": gbuf_views})
torch.distributed.reduce_scatter(
gbuf_views[data_parallel_rank],
gbuf_views,
group = data_parallel_group,
)
# except:
# pax(0, {
# "data_parallel_rank" : data_parallel_rank,
# "gbuf_views" : gbuf_views,
# })
# else:
# torch.distributed.all_reduce(
# gbuf,
# group = data_parallel_group,
# )
# timers('backward-params-reduce-scatter').stop()
timers('backward-params-all-reduce').stop()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
def gather_params(self, ITERATION):
def gather_params(self, args, timers, ITERATION):
# >>>
timers = get_timers()
# timers = get_timers()
# <<<
timers('backward-params-all-gather').start()
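
The reduce_grads portion of the hunk above replaces the old all-reduce with a reduce-scatter over equal-size views of each model's flat gradient buffer, so every data-parallel rank ends up owning only its shard of the averaged gradients. Below is a standalone sketch of that pattern: a hypothetical helper assuming a flat 1-D buffer whose length is divisible by the data-parallel world size, not the actual DistributedOptimizer code.

import torch
import torch.distributed as dist

def reduce_scatter_grad_buffer(gbuf, group):
    # Average a flat gradient buffer across the data-parallel group,
    # leaving this rank's shard reduced in place.
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    shard_size = gbuf.numel() // world_size
    # One contiguous view per rank over the flat buffer.
    gbuf_views = [gbuf[r * shard_size:(r + 1) * shard_size]
                  for r in range(world_size)]
    # Pre-divide so the summing reduce-scatter produces an average.
    gbuf /= world_size
    # Each rank receives the reduced copy of its own shard.
    dist.reduce_scatter(gbuf_views[rank], gbuf_views, group=group)
    return gbuf_views[rank]

Compared with a full all-reduce, each rank keeps only the shard of averaged gradients that its optimizer partition needs, which is what allows the later per-shard update and the parameter all-gather in gather_params.
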
@@ -177,10 +177,10 @@ class MegatronOptimizer(ABC):
@abstractmethod
def step(self):
def step(self, args, timers):
pass
def gather_params(self, ITERATION):
def gather_params(self, args, timers, ITERATION):
pass
def allreduce_word_embedding_grads(self):
@@ -400,9 +400,9 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
@torch.no_grad()
def step(self, ITERATION):
def step(self, args, timers, ITERATION):
timers = get_timers()
# timers = get_timers()
# >>>
# self.debug_model(ITERATION, "before copy grad.", 0)
@@ -720,7 +720,7 @@ class FP32Optimizer(MegatronOptimizer):
@torch.no_grad()
def step(self):
def step(self, args, timers):
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
@@ -437,18 +437,17 @@ def train_step(forward_step_func, data_iterator,
# >>>
# Reduce gradients.
optimizer.reduce_grads(args, timers) # model)
optimizer.reduce_grads(args, timers)
# <<<
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(ITERATION)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers, ITERATION)
timers('optimizer').stop()
# >>>
# Gather params gradients. (with distributed optimizer option, optimizer
# now responsible for gathering updated params)
optimizer.gather_params(ITERATION)
# Gather params.
optimizer.gather_params(args, timers, ITERATION)
# <<<
# >>>