Commit 94a90215 authored by Lawrence McAfee

Removed the faulty (already commented-out) gather_model_params param iterator.

parent 64b9d28c
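
For context on the change below: the faulty iterator looped over self.param_gbuf_map (flagged in its own comment as an incomplete param list), while the loop that remains walks every model chunk's _grad_buffer_param_index_map. The following is a minimal, self-contained sketch of that kept iteration pattern, not Megatron-LM source; _FakeModelChunk and iter_grad_buffer_params are hypothetical stand-ins, and only the attribute name mirrors the diff.

    import torch


    class _FakeModelChunk:
        """Hypothetical stand-in for a DDP model chunk holding a grad-buffer param index map."""

        def __init__(self, param_index_map):
            # dtype -> {param: (start_index, end_index)} within the contiguous grad buffer.
            self._grad_buffer_param_index_map = param_index_map


    def iter_grad_buffer_params(models):
        """Visit every param registered in every model chunk's grad buffers,
        mirroring the kept triple loop in gather_model_params."""
        for model in models:
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                for param in param_map:
                    yield dtype, param


    if __name__ == "__main__":
        p1 = torch.nn.Parameter(torch.zeros(4))
        p2 = torch.nn.Parameter(torch.zeros(2))
        chunk = _FakeModelChunk({torch.float32: {p1: (0, 4), p2: (4, 6)}})
        assert sum(1 for _ in iter_grad_buffer_params([chunk])) == 2

Iterating per model chunk and per dtype covers every parameter that lives in a contiguous grad buffer, which is why the diff keeps this loop and drops the single flat param_gbuf_map loop.
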
@@ -412,10 +412,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def gather_model_params(self, args, timers, ITERATION):
# >>>
# timers = get_timers()
# <<<
timers('backward-params-all-gather').start()
data_parallel_rank = mpu.get_data_parallel_rank()
@@ -433,7 +429,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Each model param now contains its updated values in its
# '.main_grad' field.
# for param in self.param_gbuf_map: # ... incomplete param list.
for model in self.models:
for dtype, param_map in model._grad_buffer_param_index_map.items():
for param in param_map:
@@ -441,23 +436,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
timers('backward-params-all-gather').stop()
# pax(0, {"gbuf_view_items": gbuf_view_items})
# >>>
# self.debug_main(ITERATION, "after/inside gather_model_params.", 0)
# self.debug_model(ITERATION, "after/inside gather_model_params.", 0)
# if ITERATION == 2:
# pax(1, {
# "ITERATION" : ITERATION,
# # "gbufs" : [
# # tp(b.data)
# # for m in self.models
# # for b in m._grad_buffers.values()
# # ],
# "param_gbuf_map" : [ str(tuple(p.shape)) for p in self.param_gbuf_map ],
# })
# <<<
def _collect_main_grad_data_for_unscaling(self):
return [ g.data for g in self.get_main_grads() ]
@@ -494,13 +472,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
assert main_shard.size == model_shard.size
# pax(0, {
# "model_param" : tp(model_param),
# "main_shard" : str(main_shard),
# "param shard" : self.model_gbuf_shards \
# [model_index][dtype]["param_map"][model_param],
# })
# Copy from DDP's contiguous buffer to main shard's grad.
model_grad = self.models[model_index]._grad_buffers[dtype].data
main_grad = self.get_main_grad(group_index)
@@ -511,35 +482,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
main_view.detach().copy_(model_view)
# pax(0, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# # "param" : tp(param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "model_grad" : tp(model_grad),
# "main_grad" : tp(main_grad),
# "model_view" : tp(model_view),
# "main_view" : tp(main_view),
# "model_shard" : str(model_shard),
# "main_shard" : str(main_shard),
# })
# >>>
# if 1 or ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** fix. **",
# "ITERATION" : ITERATION,
# # "model grads" : self.get_world_model_grads(),
# "main_grads" : self.get_main_grads(),
# "group shards" : [
# "group %d; %s" % (grp_idx, main_shard)
# for grp_idx, grp_shard in enumerate(self.opt_group_shards)
# for model_param, main_shard in grp_shard["param_map"].items()
# ],
# })
# <<<
def _copy_main_params_to_model_params(self, ITERATION):
@@ -562,30 +504,3 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
model_view.detach().copy_(main_view)
# Debug.
# pax(1, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# "model_param" : tp(model_param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "model_param" : tp(model_param),
# "main_param" : tp(main_param),
# "model_view" : tp(model_view),
# "main_view" : tp(main_view),
# "model_shard" : str(model_shard),
# "main_shard" : str(main_shard),
# })
# >>>
# if ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** fix. **",
# "ITERATION" : ITERATION,
# "model params" : self.get_world_model_params(),
# })
# <<<
# <<<
@@ -196,9 +196,6 @@ class MegatronOptimizer(ABC):
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
# raise Exception("[main] ready for weight sync?")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -240,29 +237,11 @@ class MegatronOptimizer(ABC):
def reduce_model_grads(self, args, timers):
# pax(0, {
# "*models" : self.models,
# "model" : model,
# })
# >>>
# from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
# from megatron import get_args
# from megatron import get_timers
# from megatron.model import DistributedDataParallel as LocalDDP
# from megatron.model import Float16Module
# from megatron.utils import unwrap_model
# args = get_args()
# timers = get_timers()
# <<<
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in self.models:
model_module.allreduce_gradients()
for model in self.models:
model.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce embedding grads.
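
The last hunk above also touches the local-DDP gradient all-reduce in reduce_model_grads. The sketch below illustrates that control flow under stated assumptions; _Timer, _LocalDDPChunk, and reduce_model_grads_sketch are hypothetical stand-ins, not the Megatron-LM implementation.

    import time


    class _Timer:
        """Minimal stand-in for Megatron's named timers."""

        def __init__(self):
            self._start = None
            self.elapsed = 0.0

        def start(self):
            self._start = time.time()

        def stop(self):
            self.elapsed += time.time() - self._start


    class _LocalDDPChunk:
        """Hypothetical model chunk exposing allreduce_gradients()."""

        def __init__(self):
            self.reduced = False

        def allreduce_gradients(self):
            # The real method all-reduces the chunk's contiguous grad buffers
            # across data-parallel ranks; here we only record the call.
            self.reduced = True


    def reduce_model_grads_sketch(models, timers, ddp_impl="local"):
        # Mirrors the control flow in the hunk: time the per-chunk gradient
        # all-reduce only when the local DDP implementation is in use.
        if ddp_impl == "local":
            timers["backward-params-all-reduce"].start()
            for model in models:
                model.allreduce_gradients()
            timers["backward-params-all-reduce"].stop()


    if __name__ == "__main__":
        timers = {"backward-params-all-reduce": _Timer()}
        chunks = [_LocalDDPChunk(), _LocalDDPChunk()]
        reduce_model_grads_sketch(chunks, timers)
        assert all(c.reduced for c in chunks)
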