Commit 371a8828 authored by Lawrence McAfee

isolated grad discrepancy to BaseFloat16Optimizer.step()

parent e7f0cdee
@@ -32,7 +32,7 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
 # >>>
 from lutil import pax, tp
-DEBUG_ITERATION = 10
+DEBUG_ITERATION = 1 # 10
 # <<<
@@ -278,6 +278,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
 self._copy_model_grads_to_main_grads(ITERATION)
 timers('optimizer-copy-to-main-grad').stop()
+# >>>
+# pax(0, {
+# "[LOC]" : "[** BEFORE UNSCALE **]",
+# "param_group / params" : [ p for g in self.optimizer.param_groups for p in g["params"] ],
+# "param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
+# })
+# <<<
 # pax(0, {
 # "params" : self.get_parameters(), # self.main_param_shards,
 # "grads" : [ p.grad for p in self.get_parameters() ], # self.main_param_shards ],
@@ -305,6 +313,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
 })
 return False, None, None
+# >>>
+pax(0, {
+"[LOC]" : "[** BEFORE CLIP **]",
+"param_group / params" : [ p for g in self.optimizer.param_groups for p in g["params"] ],
+"param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
+})
+# <<<
 # Clip the main gradients.
 timers('optimizer-clip-main-grad').start()
 grad_norm = None
@@ -316,16 +332,18 @@ class BaseFloat16Optimizer(MegatronOptimizer):
 num_zeros_in_grad = self.count_zeros() if \
 self.log_num_zeros_in_grad else None
-# Step the optimizer.
-self.optimizer.step()
 # >>>
-# pax(0, {
+pax(0, {
 # "main params" : self.get_main_params(),
 # "main grads" : self.get_main_grads(),
-# })
+**{"param_groups / %d" % i : g for i, g in enumerate(self.optimizer.param_groups)},
+"param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
+})
 # <<<
+# Step the optimizer.
+self.optimizer.step()
 # Update params from main params.
 timers('optimizer-copy-main-to-model-params').start()
 self._copy_main_params_to_model_params(ITERATION)
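Editor's note: the commit message says the gradient discrepancy was isolated to BaseFloat16Optimizer.step(), and the reordered debug block above now dumps the optimizer's param groups and their gradients immediately before the step. A minimal way to take comparable snapshots around the suspect call with plain PyTorch is sketched below; the helper name and print format are illustrative, not part of Megatron.

import torch

def snapshot_optimizer(optimizer, tag):
    """Print per-group param/grad norms so two runs can be diffed line by line."""
    for index, group in enumerate(optimizer.param_groups):
        params = [p for p in group["params"] if p.grad is not None]
        if not params:
            continue
        # Accumulate in fp32 to avoid fp16 overflow while summing norms.
        param_norm = torch.norm(torch.stack([p.detach().float().norm() for p in params]))
        grad_norm = torch.norm(torch.stack([p.grad.detach().float().norm() for p in params]))
        print(f"[{tag}] group {index}: param={param_norm.item():.8e} grad={grad_norm.item():.8e}")

# Hypothetical usage around the call under suspicion:
#   snapshot_optimizer(self.optimizer, "before step")
#   self.optimizer.step()
#   snapshot_optimizer(self.optimizer, "after step")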
@@ -415,6 +433,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
 fp32_from_float16_params_this_group.append(main_param)
 # Reset existing state dict key to the new main param.
 if param in self.optimizer.state:
+# >>>
+raise Exception("hi.")
+# <<<
 self.optimizer.state[main_param] \
 = self.optimizer.state.pop(param)
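Editor's note: the surrounding context re-keys any existing optimizer state from the fp16 model param to its new fp32 main param so momentum/Adam buffers survive the swap; the added raise simply asserts that this branch is never hit in the run being debugged. A self-contained sketch of the same re-keying idea follows (toy SGD example, plain fp32 tensors, not Megatron code):

import torch

# Toy stand-ins: in Megatron the old key is the fp16 model param and the new key
# is its fp32 master copy; fp32 tensors are used here for simplicity.
model_param = torch.nn.Parameter(torch.randn(4))
optimizer = torch.optim.SGD([model_param], lr=0.1, momentum=0.9)
model_param.grad = torch.ones_like(model_param)
optimizer.step()                          # creates a momentum buffer keyed by model_param

main_param = torch.nn.Parameter(model_param.detach().clone())
optimizer.param_groups[0]["params"] = [main_param]
if model_param in optimizer.state:
    # Re-key the state entry so the momentum buffer follows the replacement param.
    optimizer.state[main_param] = optimizer.state.pop(model_param)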
@@ -483,6 +504,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
 timers = get_timers()
 # <<<
+# >>>
+# pax(0, {
+# "grads" : [ p.main_grad for m in model for p in m.parameters() ],
+# })
+# <<<
 # All-reduce if needed.
 if args.DDP_impl == 'local':
 timers('backward-params-all-reduce').start()
@@ -490,6 +517,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
 model_module.allreduce_gradients()
 timers('backward-params-all-reduce').stop()
+# >>>
+# pax(0, {
+# "grads" : [ p.main_grad for m in model for p in m.parameters() ],
+# })
+# <<<
 # All-reduce word_embeddings' grad across first and last stages to ensure
 # that word_embeddings parameters stay in sync.
 # This should only run for models that support pipelined model parallelism
@@ -497,6 +530,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
 timers('backward-embedding-all-reduce').start()
 if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
 mpu.get_pipeline_model_parallel_world_size() > 1:
+# >>>
+raise Exception("hi.")
+# <<<
 if mpu.is_pipeline_first_stage(ignore_virtual=True):
 unwrapped_model = model[0]
 elif mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -576,6 +612,16 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
 if not self.use_contiguous_buffers_in_local_ddp:
 model_param.main_grad = None
+# >>>
+# if ITERATION == DEBUG_ITERATION:
+# pax(0, {
+# "** branch **" : "** main. **",
+# "ITERATION" : ITERATION,
+# "model grads" :
+# [ p.main_grad for m in self.models for p in m.parameters() ],
+# })
+# <<<
 def _collect_main_grad_data_for_unscaling(self):
 main_grads = []
@@ -623,7 +669,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
 pax(0, {
 "** branch **" : "** main. **",
 "ITERATION" : ITERATION,
-"model params" : [p for m in self.models for p in m.parameters() ],
+"model params" : [p for m in self.models for p in m.parameters()],
 })
 # <<<
@@ -984,9 +1030,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 [ g["orig_group"] for g in self.opt_group_shards ]
 self.optimizer.load_state_dict(self.optimizer.state_dict())
+# pax(0, {
+# # "opt_group_shards" : self.opt_group_shards,
+# # "param_groups" : self.optimizer.param_groups,
+# "optimizer" : self.optimizer,
+# "optimizer / state" : self.optimizer.state,
+# })
 # pax(1, {
-# "opt_group_shards" : self.opt_group_shards,
-# "param_groups" : self.optimizer.param_groups,
+# "optimizer" : self.optimizer,
+# **{"optimizer / param_groups / %d" % i : g
+# for i, g in enumerate(self.optimizer.param_groups)},
+# "optimizer / state" : self.optimizer.state,
+# "optimizer / state_dict" : self.optimizer.state_dict(),
 # })
 # Initialize main params.
@@ -1028,6 +1083,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 def get_world_model_params(self):
 '''** FOR DEBUGGING. **'''
 return [ p for m in self.models for p in m.parameters() ]
+def get_world_model_grads(self):
+'''** FOR DEBUGGING. **'''
+return [ p.main_grad for p in self.get_world_model_params() ]
 def get_main_params(self):
 return [ g["params"][0] for g in self.optimizer.param_groups ]
@@ -1075,20 +1133,25 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 for model_index, model in enumerate(self.models):
 for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
 world_shards = gbuf_shard["world_all"]
-gbuf = model._grad_buffers[dtype]
-gbuf_views = []
-for shard in world_shards:
-gbuf_views.append(gbuf.data[shard.start:shard.end])
+gbuf = model._grad_buffers[dtype].data
+gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
 gbuf_view_items.append((model_index, dtype, gbuf_views))
+# pax(0, {
+# "world_shards" : world_shards,
+# "gbuf_views" : gbuf_views,
+# })
 # pax(0, {"gbuf_view_items": gbuf_view_items})
 return gbuf_view_items
 def reduce_grads(self, model):
+# >>>
+timers = get_timers()
+# <<<
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Sync word embedding params.
@@ -1101,6 +1164,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 timers('backward-embedding-all-reduce').start()
 if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
 mpu.get_pipeline_model_parallel_world_size() > 1:
+# >>>
+raise Exception("hi.")
+# <<<
 if mpu.is_pipeline_first_stage(ignore_virtual=True):
 unwrapped_model = model[0]
 elif mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -1116,6 +1182,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 if args.DDP_impl == 'local':
 grad = word_embeddings_weight.main_grad
 else:
+raise Exception("only 'main_grad' supported for distrib-opt.")
 grad = word_embeddings_weight.grad
 torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
 # +++
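Editor's note: the added raise documents an assumption rather than new behavior: with the distributed optimizer, gradients live only in the contiguous-buffer view `.main_grad`, so falling through to `.grad` would sync the wrong tensor. The surrounding logic is the usual tied-embedding sync, an all-reduce of the shared word-embedding gradient over the embedding group. A hedged sketch, with process-group setup omitted and the boolean flag purely illustrative:

import torch
import torch.distributed as dist

def sync_tied_word_embedding_grad(word_embeddings_weight, embedding_group,
                                  use_contiguous_main_grad=True):
    """Sum the tied word-embedding gradient across the first and last pipeline
    stages so both copies of the weight take identical optimizer steps."""
    if use_contiguous_main_grad:
        grad = word_embeddings_weight.main_grad   # view into the DDP grad buffer
    else:
        grad = word_embeddings_weight.grad        # plain autograd gradient
    dist.all_reduce(grad, group=embedding_group)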
@@ -1123,7 +1190,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 # torch.distributed.all_reduce(grad_shard,
 # group=mpu.get_embedding_group())
 # <<<
+timers('backward-embedding-all-reduce').stop()
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Sync T5 position embedding params.
@@ -1133,18 +1200,30 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Reduce-scatter.
 data_parallel_rank = mpu.get_data_parallel_rank()
+data_parallel_world_size = mpu.get_data_parallel_world_size()
 data_parallel_group = mpu.get_data_parallel_group()
 gbuf_view_items = self.get_model_grad_buffer_dp_views()
+# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
 for model_index, dtype, gbuf_views in gbuf_view_items:
+# coalesced /= mpu.get_data_parallel_world_size()
+gbuf = self.models[model_index]._grad_buffers[dtype].data
+torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
+# gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
+# gbuf_d
+# pax(0, {
+# "data_parallel_world_size" : data_parallel_world_size,
+# "gbuf" : tp(gbuf),
+# })
 torch.distributed.reduce_scatter(
 gbuf_views[data_parallel_rank],
 gbuf_views,
 group = data_parallel_group,
 )
-# pax(0, {"gbuf_view_items": gbuf_view_items})
+# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
 def gather_params(self):
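Editor's note: the new reduce-scatter path first scales the whole contiguous grad buffer by 1 / data_parallel_world_size, then reduce-scatters the per-rank views built by get_model_grad_buffer_dp_views, so each rank ends up holding the data-parallel average of only its own shard. A hedged sketch of that pattern with torch.distributed follows; process-group setup is omitted and the helper assumes the buffer length divides evenly by the world size:

import torch
import torch.distributed as dist

def average_and_reduce_scatter_(gbuf, group=None):
    """Scale a flat grad buffer in place, then reduce-scatter it so this rank's
    slice holds the average of that slice across data-parallel ranks."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    shard = gbuf.numel() // world_size            # assumes an evenly padded buffer
    # One zero-copy view per rank; writing to a view writes into gbuf itself.
    views = [gbuf[r * shard:(r + 1) * shard] for r in range(world_size)]

    gbuf.mul_(1.0 / world_size)                   # pre-divide: the reduce yields a mean, not a sum
    dist.reduce_scatter(views[rank], views, group=group)
    return views[rank]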
@@ -1161,24 +1240,12 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 group = data_parallel_group,
 )
-# Each model param now contains its updated values in it's
+# Each model param now contains its updated values in its
 # '.main_grad' field.
 for param in self.param_gbuf_map:
 param.detach().copy_(param.main_grad)
-# pax(0, {
-# "param" : tp(param),
-# "main_grad" : tp(param.main_grad),
-# # "grad" : tp(param.grad),
-# })
-# pax(1, {
-# "data_parallel_rank" : data_parallel_rank,
-# "main params" : self.get_main_params(),
-# "model params / world" : self.get_world_model_params(),
-# **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
-# # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
-# # "model params / local" : self.get_local_model_param_views(),
-# })
+# pax(0, {"gbuf_view_items": gbuf_view_items})
 def _collect_main_grad_data_for_unscaling(self):
 return [ g.data for g in self.get_main_grads() ]
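Editor's note: gather_params is the mirror image of the reduce-scatter above: the updated fp32 shards are all-gathered back into the same contiguous buffer that backed the gradients, so each model param's `.main_grad` view temporarily aliases its fresh values and only needs to be copied into the fp16 param. A simplified sketch; names mirror the surrounding code and the param-to-view mapping is assumed to already exist:

import torch
import torch.distributed as dist

def gather_params_sketch(gbuf_views, model_params, group=None):
    """All-gather each rank's updated shard into the shared flat buffer, then copy
    every param's aliased '.main_grad' view back into the fp16 parameter."""
    rank = dist.get_rank(group=group)
    dist.all_gather(gbuf_views, gbuf_views[rank], group=group)
    for param in model_params:
        # param.main_grad is assumed to be a view into the gathered buffer that
        # covers exactly this param's elements (and now holds updated values).
        param.detach().copy_(param.main_grad)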
@@ -1199,51 +1266,29 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 # Copy shard data.
 main_view = main_param[main_shard.start:main_shard.end]
 model_view = model_param.view(-1)[model_shard.start:model_shard.end]
-# try:
 main_view.detach().copy_(model_view)
-# except:
-# pax({
-# "main_param" : tp(main_param),
-# "model_param" : tp(model_param),
-# "main_view" : tp(main_view),
-# "model_view" : tp(model_view),
-# "main_shard" : str(main_shard),
-# "model_shard" : str(model_shard),
-# })
-# pax(0, {
-# **{
-# "opt_group_shards / %d" % i : s
-# for i, s in enumerate(self.opt_group_shards)
-# },
-# "main_params" : self.get_main_params(),
-# })
 def _copy_model_grads_to_main_grads(self, ITERATION):
-# >>>
-model_grads = self.get_local_model_grad_views()
-model_has_nan = self.has_nan_debug(model_grads)
-if model_has_nan:
-pax(1, {
-"ITERATION" : ITERATION,
-"model grads" : model_grads,
-"model_has_nan" : model_has_nan,
-"model params / local" : self.get_local_model_param_views(),
-# "model params / world" : [ list(self.param_gbuf_map),
-# "main grads" : self.get_main_grads(),
-})
-# <<<
 for group_index, group_shard in enumerate(self.opt_group_shards):
 for model_param, main_shard in group_shard["param_map"].items():
+# Model shard.
 model_index, dtype = self.param_gbuf_map[model_param]
 model_shard = self.model_gbuf_shards \
 [model_index][dtype]["param_map"][model_param]["gbuf_world"]
 assert main_shard.size == model_shard.size
-# pax(0, {
-# "model_param" : tp(model_param),
-# "main_shard" : str(main_shard),
-# "param shard" : self.model_gbuf_shards \
-# [model_index][dtype]["param_map"][model_param],
-# })
 # Copy from DDP's contiguous buffer to main shard's grad.
 model_grad = self.models[model_index]._grad_buffers[dtype].data
 main_grad = self.get_main_grad(group_index)
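Editor's note: each fp32 main gradient here is a flat shard whose `[start, end)` range mirrors a slice of the model's contiguous fp16 grad buffer, so the copy performed just after the shown context reduces to a pair of 1-D views. A standalone illustration of that shard copy, with made-up sizes and offsets:

import torch

model_grad_buffer = torch.randn(1024, dtype=torch.float16)   # stand-in DDP grad buffer
main_grad = torch.zeros(256, dtype=torch.float32)            # this rank's fp32 grad shard

gbuf_start, gbuf_end = 256, 512     # hypothetical world range owned by this shard
shard_start, shard_end = 0, 256     # where that range lands inside main_grad

# copy_() upcasts the fp16 slice into the fp32 shard element-wise.
main_grad[shard_start:shard_end].copy_(model_grad_buffer[gbuf_start:gbuf_end])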
@@ -1269,38 +1314,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 # })
 # >>>
-# pax(1, {
-# # "model_gbuf_shards" : self.model_gbuf_shards,
-# **{
-# "opt_group_shards / %d" % i : s
-# for i, s in enumerate(self.opt_group_shards)
-# },
-# "main_grads" : self.get_main_grads(),
-# })
-# for group_index, main_grad in enumerate(self.get_main_grads()):
-# # is_nan = torch.any(torch.isnan(main_grad)).item()
-# if is_nan:
-# # opt_group_shard = self.opt_group_shards[group_index]
-# # param_views = []
-# # for param, shard in opt_group_shard["param_map"].items():
-# # ddd
-# pax(0, {
-# "opt_group_shard" : self.opt_group_shards[group_index],
-# "param_map" : [ (str(p.shape), str(d)) for p, d in self.opt_group_shards[group_index]["param_map"].items() ],
-# "gbufs" : [ b.data for m in self.models for d, b in m._grad_buffers.items() ],
-# "group_index" : group_index,
-# "main_param" : tp(self.get_main_param(group_index)),
-# "main_grad" : tp(main_grad),
-# "is_nan" : is_nan,
-# })
-main_grads = self.get_main_grads()
-main_has_nan = self.has_nan_debug(main_grads)
-if main_has_nan:
-raise Exception("hi.")
-# pax(1, {
-# "model grads" : self.get_local_model_grad_views(),
-# })
+# if ITERATION == DEBUG_ITERATION:
+# pax(0, {
+# "** branch **" : "** fix. **",
+# "ITERATION" : ITERATION,
+# # "model grads" : self.get_world_model_grads(),
+# "main_grads" : self.get_main_grads(),
+# })
 # <<<
@@ -1340,27 +1360,12 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
 # "main_shard" : str(main_shard),
 # })
-# pax(0, {
-# "model_gbuf_shards" : self.model_gbuf_shards,
-# "opt_group_shards" : self.opt_group_shards,
-# })
 # >>>
-for param in self.param_gbuf_map:
-# is_nan = torch.any(torch.isnan(param)).item()
-is_nan = not torch.all(torch.isfinite(param)).item()
-if is_nan:
-pax({
-"param" : tp(param),
-"is_nan" : is_nan,
-})
 if ITERATION == DEBUG_ITERATION:
 pax(0, {
 "** branch **" : "** fix. **",
 "ITERATION" : ITERATION,
-# "main params" : self.get_main_params(),
-# "model params / local" : self.get_local_model_param_views(),
-"model params" : [p for m in self.models for p in m.parameters()],
+"model params" : self.get_world_model_params(),
 })
 # <<<
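Editor's note: both branches dump their model parameters at DEBUG_ITERATION under matching "** main. **" and "** fix. **" labels, which suggests the debugging loop is to run each branch, save the tensors, and diff them offline. One way such a comparison could look; the file paths and tolerance below are purely illustrative:

import torch

main_params = torch.load("params_main_iter1.pt")   # hypothetical dump from the main branch
fix_params = torch.load("params_fix_iter1.pt")     # hypothetical dump from the fix branch

assert len(main_params) == len(fix_params)
for index, (a, b) in enumerate(zip(main_params, fix_params)):
    max_diff = (a.float() - b.float()).abs().max().item()
    status = "OK" if max_diff < 1e-6 else "MISMATCH"
    print(f"param {index}: shape={tuple(a.shape)} max_abs_diff={max_diff:.3e} {status}")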
...
@@ -432,7 +432,7 @@ def train_step(forward_step_func, data_iterator,
 # >>>
 # Reduce gradients. (with distributed optimizer option, optimizer
 # now responsible for reducing gradients)
-optimizer.reduce_gradients(model)
+optimizer.reduce_grads(model)
 # <<<
 # >>>
...