Commit e7f0cdee authored by Lawrence McAfee

renamed reduce_gradients -> reduce_grads [ matches gather_params ]
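
For context, a minimal sketch of the optimizer interface after this rename. This is illustrative only, assuming the MegatronOptimizer ABC shown in the diff below; the comments are mine, not from the source.

# Sketch (not the actual file): the gradient-side hook is renamed so it pairs
# naturally with gather_params on the parameter side.
from abc import ABC, abstractmethod

class MegatronOptimizer(ABC):

    @abstractmethod
    def reduce_grads(self):
        # Reduce gradients across data-parallel ranks before step().
        # (The concrete optimizers in the diff also accept the model list.)
        pass

    @abstractmethod
    def gather_params(self):
        # Gather updated parameter shards back to the model after step().
        pass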

parent 4b843668
@@ -31,6 +31,8 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
 # >>>
 from lutil import pax, tp
+DEBUG_ITERATION = 10
 # <<<
@@ -130,7 +132,7 @@ class MegatronOptimizer(ABC):
     @abstractmethod
-    def reduce_gradients(self):
+    def reduce_grads(self):
         pass
@@ -466,7 +468,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
     # >>>
-    def reduce_gradients(self, model):
+    def reduce_grads(self, model):
         # >>>
         from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -481,26 +483,10 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         timers = get_timers()
         # <<<

-        # >>>
-        # if not args.use_distributed_optimizer:
         # All-reduce if needed.
-        # >>>
-        # if args.DDP_impl == 'local' and not args.use_distributed_optimizer:
         if args.DDP_impl == 'local':
-        # <<<
             timers('backward-params-all-reduce').start()
             for model_module in model:
-                # >>>
-                # from lutil import pax, tp
-                # pax(0, {
-                #     "model" : model,
-                #     "model_module" : model_module,
-                # })
-                # <<<
-                # >>>
-                # e.g., grad_shard = optimizer.get_grad_shard()
-                # <<<
                 model_module.allreduce_gradients()
             timers('backward-params-all-reduce').stop()
@@ -559,7 +545,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
     def gather_params(self):
         pass

-    def _copy_model_grads_to_main_grads(self):
+    def _copy_model_grads_to_main_grads(self, ITERATION):
         # This only needs to be done for the float16 group.
         for model_group, main_group in zip(self.float16_groups,
                                            self.fp32_from_float16_groups):
@@ -627,11 +613,19 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         return model_data, main_data

-    def _copy_main_params_to_model_params(self):
+    def _copy_main_params_to_model_params(self, ITERATION):
         # Only needed for the float16 params.
         model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                         overflow_buf=self._dummy_overflow_buf)
+        # >>>
+        if ITERATION == DEBUG_ITERATION:
+            pax(0, {
+                "** branch **" : "** main. **",
+                "ITERATION" : ITERATION,
+                "model params" : [p for m in self.models for p in m.parameters() ],
+            })
+        # <<<

     def _copy_model_params_to_main_params(self):
@@ -766,14 +760,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 "gbuf_local" : param_local_shard,
                 "param" : sub_param_shard,
             }
-            pax(1, {
-                "gbuf_world_shard" : gbuf_world_shard,
-                "param shards" : param_shard_map[param],
-            })
-            # >>>
-            # if param_world_start < gbuf_world_shard.start:
-            #     pax({"param shards": param_shard_map[param]})
-            # <<<

         # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
@@ -1070,10 +1056,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # for main_group in self.optimizer.param_groups:
         #     main_params.extend(main_group["params"])

-        _zero_grad_group_helper(model_params, set_to_none)
+        # ** using contiguous buffer; don't set_to_none **
+        _zero_grad_group_helper(model_params, set_to_none = False) # set_to_none)
         # _zero_grad_group_helper(params, set_to_none = False)
-        # pax(0, {"params": params})
+        # pax(0, {"model_params": model_params})

     def get_model_grad_buffer_dp_views(self):
@@ -1100,13 +1087,44 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         return gbuf_view_items

-    def reduce_gradients(self, model):
+    def reduce_grads(self, model):

         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Sync word embedding params.
         # ... todo ...
+        # All-reduce word_embeddings' grad across first and last stages to ensure
+        # that word_embeddings parameters stay in sync.
+        # This should only run for models that support pipelined model parallelism
+        # (BERT and GPT-2).
+        timers('backward-embedding-all-reduce').start()
+        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
+                mpu.get_pipeline_model_parallel_world_size() > 1:
+            if mpu.is_pipeline_first_stage(ignore_virtual=True):
+                unwrapped_model = model[0]
+            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
+                unwrapped_model = model[-1]
+            else: # We do not support the interleaved schedule for T5 yet.
+                unwrapped_model = model[0]
+            unwrapped_model = unwrap_model(
+                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+            if unwrapped_model.share_word_embeddings:
+                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
+                # >>>
+                if args.DDP_impl == 'local':
+                    grad = word_embeddings_weight.main_grad
+                else:
+                    grad = word_embeddings_weight.grad
+                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
+                # +++
+                # grad_shard = optimizer.get_grad_shard(word_embeddings)
+                # torch.distributed.all_reduce(grad_shard,
+                #                              group=mpu.get_embedding_group())
+                # <<<

         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Sync T5 position embedding params.
@@ -1153,27 +1171,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     # "grad" : tp(param.grad),
         #     })
-        # pax(0, {
-        #     "gbuf_view_items" : gbuf_view_items,
-        #     "param_gbuf_map" : [
-        #         (str(tuple(p.shape)), d)
-        #         for p, d in self.param_gbuf_map.items()
-        #     ],
-        #
-        # })
-        pax(1, {
-            "data_parallel_rank" : data_parallel_rank,
-            "main params" : self.get_main_params(),
-            # "model params / world" : self.get_world_model_params(),
-            **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
-            # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
-            # "model params / local" : self.get_local_model_param_views(),
-        })
+        # pax(1, {
+        #     "data_parallel_rank" : data_parallel_rank,
+        #     "main params" : self.get_main_params(),
+        #     "model params / world" : self.get_world_model_params(),
+        #     **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
+        #     # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
+        #     # "model params / local" : self.get_local_model_param_views(),
+        # })

     def _collect_main_grad_data_for_unscaling(self):
-        # return [ p.grad.data for p in self.main_param_shards ]
-        # return [ p.grad.data for p in self.main_param_shards if p is not None ]
-        # return [ self.get_main_grad(gi).data
-        #          for gi in range(len(self.opt_group_shards)) ]
         return [ g.data for g in self.get_main_grads() ]

     def _copy_model_params_to_main_params(self):
@@ -1319,19 +1326,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                     model_view.detach().copy_(main_view)

                     # Debug.
-                    pax(1, {
-                        "group_index" : group_index,
-                        "group_shard" : group_shard,
-                        "model_param" : tp(model_param),
-                        "model_index" : model_index,
-                        "dtype" : str(dtype),
-                        "model_param" : tp(model_param),
-                        "main_param" : tp(main_param),
-                        "model_view" : tp(model_view),
-                        "main_view" : tp(main_view),
-                        "model_shard" : str(model_shard),
-                        "main_shard" : str(main_shard),
-                    })
+                    # pax(1, {
+                    #     "group_index" : group_index,
+                    #     "group_shard" : group_shard,
+                    #     "model_param" : tp(model_param),
+                    #     "model_index" : model_index,
+                    #     "dtype" : str(dtype),
+                    #     "model_param" : tp(model_param),
+                    #     "main_param" : tp(main_param),
+                    #     "model_view" : tp(model_view),
+                    #     "main_view" : tp(main_view),
+                    #     "model_shard" : str(model_shard),
+                    #     "main_shard" : str(main_shard),
+                    # })

         # pax(0, {
         #     "model_gbuf_shards" : self.model_gbuf_shards,
@@ -1347,12 +1354,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 "is_nan" : is_nan,
             })

-        # model_param_views = self.get_local_model_param_views()
-        # pax(1, {
-        #     "ITERATION" : ITERATION,
-        #     "main params" : self.get_main_params(),
-        #     "model params / local" : self.get_local_model_param_views(),
-        # })
+        if ITERATION == DEBUG_ITERATION:
+            pax(0, {
+                "** branch **" : "** fix. **",
+                "ITERATION" : ITERATION,
+                # "main params" : self.get_main_params(),
+                # "model params / local" : self.get_local_model_param_views(),
+                "model params" : [p for m in self.models for p in m.parameters()],
+            })
         # <<<
         # <<<
...
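
For orientation, a rough sketch of where the renamed hooks sit in a training step. This call sequence is an assumption for illustration; train_step and its arguments are hypothetical and not part of this commit.

# Hypothetical wiring, assuming the concrete reduce_grads(model) and
# gather_params() shown in the diff above; everything else is illustrative.
def train_step(model, optimizer, loss):
    loss.backward()                # per-rank backward pass fills local grads
    optimizer.reduce_grads(model)  # sync grads across data-parallel ranks
    optimizer.step()               # update the (possibly sharded) main params
    optimizer.gather_params()      # copy/gather updated params back to the model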