Commit e7f0cdee authored by Lawrence McAfee

renamed reduce_gradients -> reduce_grads [ matches gather_params ]

parent 4b843668
@@ -31,6 +31,8 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
# >>>
from lutil import pax, tp
DEBUG_ITERATION = 10
# <<<
@@ -130,7 +132,7 @@ class MegatronOptimizer(ABC):
@abstractmethod
def reduce_gradients(self):
def reduce_grads(self):
pass
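
For orientation, a minimal sketch (not part of this commit) of how the renamed abstract hook now mirrors gather_params on the optimizer base class; the docstrings below are assumptions added only for illustration.

from abc import ABC, abstractmethod

class MegatronOptimizerSketch(ABC):

    @abstractmethod
    def reduce_grads(self):
        """Reduce gradients across data-parallel ranks (formerly reduce_gradients)."""
        ...

    @abstractmethod
    def gather_params(self):
        """Gather updated parameters after the step; reduce_grads now matches this name."""
        ...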
@@ -466,7 +468,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# >>>
def reduce_gradients(self, model):
def reduce_grads(self, model):
# >>>
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -481,26 +483,10 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
timers = get_timers()
# <<<
# >>>
# if not args.use_distributed_optimizer:
# All-reduce if needed.
# >>>
# if args.DDP_impl == 'local' and not args.use_distributed_optimizer:
if args.DDP_impl == 'local':
# <<<
timers('backward-params-all-reduce').start()
for model_module in model:
# >>>
# from lutil import pax, tp
# pax(0, {
# "model" : model,
# "model_module" : model_module,
# })
# <<<
# >>>
# e.g., grad_shard = optimizer.get_grad_shard()
# <<<
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
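
A short sketch of what the model_module.allreduce_gradients() call in this hunk amounts to on the local-DDP path: each contiguous gradient buffer is averaged across the data-parallel group. The function and argument names below are illustrative stand-ins, not Megatron's exact internals.

import torch

def allreduce_grad_buffers_sketch(grad_buffers, data_parallel_group):
    # grad_buffers: one flat, contiguous buffer per dtype (local DDP layout).
    world_size = torch.distributed.get_world_size(group=data_parallel_group)
    for buf in grad_buffers:
        buf.data /= world_size  # pre-divide so the summed all-reduce yields a mean
        torch.distributed.all_reduce(buf.data, group=data_parallel_group)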
@@ -559,7 +545,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
def gather_params(self):
pass
def _copy_model_grads_to_main_grads(self):
def _copy_model_grads_to_main_grads(self, ITERATION):
# This only needs to be done for the float16 group.
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
@@ -627,11 +613,19 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
return model_data, main_data
def _copy_main_params_to_model_params(self):
def _copy_main_params_to_model_params(self, ITERATION):
# Only needed for the float16 params.
model_data, main_data = self._get_model_and_main_params_data_float16()
_multi_tensor_copy_this_to_that(this=main_data, that=model_data,
overflow_buf=self._dummy_overflow_buf)
# >>>
if ITERATION == DEBUG_ITERATION:
pax(0, {
"** branch **" : "** main. **",
"ITERATION" : ITERATION,
"model params" : [p for m in self.models for p in m.parameters() ],
})
# <<<
def _copy_model_params_to_main_params(self):
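
For reference, a plain-loop sketch of the copy that _multi_tensor_copy_this_to_that performs above; the real helper may dispatch to a fused multi-tensor kernel when an overflow buffer is available, so this only approximates its fallback path.

def copy_this_to_that_sketch(this, that):
    # Copy each fp32 main tensor into its fp16 model counterpart, element-wise.
    for this_, that_ in zip(this, that):
        that_.copy_(this_)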
@@ -766,14 +760,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
"gbuf_local" : param_local_shard,
"param" : sub_param_shard,
}
pax(1, {
"gbuf_world_shard" : gbuf_world_shard,
"param shards" : param_shard_map[param],
})
# >>>
# if param_world_start < gbuf_world_shard.start:
# pax({"param shards": param_shard_map[param]})
# <<<
# pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
@@ -1070,10 +1056,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# for main_group in self.optimizer.param_groups:
# main_params.extend(main_group["params"])
_zero_grad_group_helper(model_params, set_to_none)
# ** using contiguous buffer; don't set_to_none **
_zero_grad_group_helper(model_params, set_to_none = False) # set_to_none)
# _zero_grad_group_helper(params, set_to_none = False)
# pax(0, {"params": params})
# pax(0, {"model_params": model_params})
def get_model_grad_buffer_dp_views(self):
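
A sketch of the behavior the "** using contiguous buffer; don't set_to_none **" comment depends on: zeroing gradients in place keeps the views into the contiguous grad buffer valid, whereas set_to_none=True would drop them. This only approximates the shared _zero_grad_group_helper.

def zero_grad_group_sketch(group, set_to_none):
    for param in group:
        if param.grad is None:
            continue
        if set_to_none:
            param.grad = None          # frees the tensor; invalidates buffer views
        else:
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            param.grad.zero_()         # keep the storage; just clear the values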
@@ -1100,13 +1087,44 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
return gbuf_view_items
def reduce_gradients(self, model):
def reduce_grads(self, model):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync word embedding params.
# ... todo ...
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
# >>>
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# +++
# grad_shard = optimizer.get_grad_shard(word_embeddings)
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_embedding_group())
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync T5 position embedding params.
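
A condensed sketch of the tied word-embedding gradient sync shown above: the first and last pipeline stages hold copies of the embedding weight, so its gradient is all-reduced over the embedding group before stepping. The helper name and the explicit embedding_group argument are invented for this sketch.

import torch

def sync_tied_word_embedding_grad_sketch(word_embeddings_weight, use_local_ddp, embedding_group):
    # Local DDP accumulates into .main_grad; torch DDP leaves grads on .grad.
    grad = word_embeddings_weight.main_grad if use_local_ddp else word_embeddings_weight.grad
    torch.distributed.all_reduce(grad, group=embedding_group)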
@@ -1153,27 +1171,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# # "grad" : tp(param.grad),
# })
# pax(0, {
# "gbuf_view_items" : gbuf_view_items,
# "param_gbuf_map" : [
# (str(tuple(p.shape)), d)
# for p, d in self.param_gbuf_map.items()
# ],
# pax(1, {
# "data_parallel_rank" : data_parallel_rank,
# "main params" : self.get_main_params(),
# "model params / world" : self.get_world_model_params(),
# **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
# # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
# # "model params / local" : self.get_local_model_param_views(),
# })
pax(1, {
"data_parallel_rank" : data_parallel_rank,
"main params" : self.get_main_params(),
# "model params / world" : self.get_world_model_params(),
**{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
# "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
# "model params / local" : self.get_local_model_param_views(),
})
def _collect_main_grad_data_for_unscaling(self):
# return [ p.grad.data for p in self.main_param_shards ]
# return [ p.grad.data for p in self.main_param_shards if p is not None ]
# return [ self.get_main_grad(gi).data
# for gi in range(len(self.opt_group_shards)) ]
return [ g.data for g in self.get_main_grads() ]
def _copy_model_params_to_main_params(self):
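
For context, a plain-loop sketch of what the main-grad tensors collected above are typically used for downstream: unscaling by the inverse loss scale and checking for non-finite values before the optimizer step. The real path may use a fused kernel; this is only an illustration.

import torch

def unscale_and_check_sketch(main_grads, loss_scale):
    inv_scale = 1.0 / loss_scale
    found_inf = False
    for grad in main_grads:
        grad.mul_(inv_scale)                 # undo loss scaling in place
        if not torch.isfinite(grad).all():
            found_inf = True                 # overflow detected; caller skips the step
    return found_inf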
@@ -1319,19 +1326,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
model_view.detach().copy_(main_view)
# Debug.
pax(1, {
"group_index" : group_index,
"group_shard" : group_shard,
"model_param" : tp(model_param),
"model_index" : model_index,
"dtype" : str(dtype),
"model_param" : tp(model_param),
"main_param" : tp(main_param),
"model_view" : tp(model_view),
"main_view" : tp(main_view),
"model_shard" : str(model_shard),
"main_shard" : str(main_shard),
})
# pax(1, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# "model_param" : tp(model_param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "model_param" : tp(model_param),
# "main_param" : tp(main_param),
# "model_view" : tp(model_view),
# "main_view" : tp(main_view),
# "model_shard" : str(model_shard),
# "main_shard" : str(main_shard),
# })
# pax(0, {
# "model_gbuf_shards" : self.model_gbuf_shards,
@@ -1347,12 +1354,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
"is_nan" : is_nan,
})
# model_param_views = self.get_local_model_param_views()
# pax(1, {
# "ITERATION" : ITERATION,
# "main params" : self.get_main_params(),
# "model params / local" : self.get_local_model_param_views(),
# })
if ITERATION == DEBUG_ITERATION:
pax(0, {
"** branch **" : "** fix. **",
"ITERATION" : ITERATION,
# "main params" : self.get_main_params(),
# "model params / local" : self.get_local_model_param_views(),
"model params" : [p for m in self.models for p in m.parameters()],
})
# <<<
# <<<