Commit 371a8828 authored by Lawrence McAfee

isolated grad discrepancy to BaseFloat16Optimizer.step()

parent e7f0cdee
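
For context, the debug probes added in this diff bracket the mixed-precision step flow: copy fp16 model grads into fp32 main grads, unscale and check for inf/nan, clip, run the inner optimizer step, then copy the updated main params back to the model params. Below is a minimal sketch of that flow using hypothetical names (`float16_optimizer_step`, `model_params`, `main_params`, `loss_scale`, `clip_norm`); it is not the Megatron implementation being debugged here.

```python
import torch

def float16_optimizer_step(optimizer, model_params, main_params,
                           loss_scale, clip_norm):
    """Sketch only; hypothetical helper, not BaseFloat16Optimizer.step()."""
    # 1. Copy (and upcast) fp16 model grads into the fp32 main grads.
    for model_p, main_p in zip(model_params, main_params):
        main_p.grad = model_p.grad.detach().float()

    # 2. Unscale main grads; skip the update if anything is non-finite.
    for main_p in main_params:
        main_p.grad.mul_(1.0 / loss_scale)
    if any(not torch.isfinite(p.grad).all() for p in main_params):
        return False  # caller lowers the loss scale and retries next iteration

    # 3. Clip the main grads and step the inner optimizer on the fp32 copies.
    torch.nn.utils.clip_grad_norm_(main_params, clip_norm)
    optimizer.step()

    # 4. Copy the updated fp32 main params back into the fp16 model params.
    with torch.no_grad():
        for model_p, main_p in zip(model_params, main_params):
            model_p.copy_(main_p)
    return True
```
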
@@ -32,7 +32,7 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
# >>>
from lutil import pax, tp
DEBUG_ITERATION = 10
DEBUG_ITERATION = 1 # 10
# <<<
@@ -278,6 +278,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
self._copy_model_grads_to_main_grads(ITERATION)
timers('optimizer-copy-to-main-grad').stop()
# >>>
# pax(0, {
# "[LOC]" : "[** BEFORE UNSCALE **]",
# "param_group / params" : [ p for g in self.optimizer.param_groups for p in g["params"] ],
# "param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
# })
# <<<
# pax(0, {
# "params" : self.get_parameters(), # self.main_param_shards,
# "grads" : [ p.grad for p in self.get_parameters() ], # self.main_param_shards ],
@@ -305,6 +313,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
})
return False, None, None
# >>>
pax(0, {
"[LOC]" : "[** BEFORE CLIP **]",
"param_group / params" : [ p for g in self.optimizer.param_groups for p in g["params"] ],
"param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
})
# <<<
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
@@ -316,16 +332,18 @@ class BaseFloat16Optimizer(MegatronOptimizer):
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# Step the optimizer.
self.optimizer.step()
# >>>
# pax(0, {
# "main params" : self.get_main_params(),
# "main grads" : self.get_main_grads(),
# })
pax(0, {
# "main params" : self.get_main_params(),
# "main grads" : self.get_main_grads(),
**{"param_groups / %d" % i : g for i, g in enumerate(self.optimizer.param_groups)},
"param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
})
# <<<
# Step the optimizer.
self.optimizer.step()
# Update params from main params.
timers('optimizer-copy-main-to-model-params').start()
self._copy_main_params_to_model_params(ITERATION)
@@ -415,6 +433,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
if param in self.optimizer.state:
# >>>
raise Exception("hi.")
# <<<
self.optimizer.state[main_param] \
= self.optimizer.state.pop(param)
@@ -483,6 +504,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
timers = get_timers()
# <<<
# >>>
# pax(0, {
# "grads" : [ p.main_grad for m in model for p in m.parameters() ],
# })
# <<<
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
@@ -490,6 +517,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# >>>
# pax(0, {
# "grads" : [ p.main_grad for m in model for p in m.parameters() ],
# })
# <<<
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
@@ -497,6 +530,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
raise Exception("hi.")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -576,6 +612,16 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
if not self.use_contiguous_buffers_in_local_ddp:
model_param.main_grad = None
# >>>
# if ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** main. **",
# "ITERATION" : ITERATION,
# "model grads" :
# [ p.main_grad for m in self.models for p in m.parameters() ],
# })
# <<<
def _collect_main_grad_data_for_unscaling(self):
main_grads = []
@@ -623,7 +669,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
pax(0, {
"** branch **" : "** main. **",
"ITERATION" : ITERATION,
"model params" : [p for m in self.models for p in m.parameters() ],
"model params" : [p for m in self.models for p in m.parameters()],
})
# <<<
@@ -984,9 +1030,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
[ g["orig_group"] for g in self.opt_group_shards ]
self.optimizer.load_state_dict(self.optimizer.state_dict())
# pax(0, {
# # "opt_group_shards" : self.opt_group_shards,
# # "param_groups" : self.optimizer.param_groups,
# "optimizer" : self.optimizer,
# "optimizer / state" : self.optimizer.state,
# })
# pax(1, {
# "opt_group_shards" : self.opt_group_shards,
# "param_groups" : self.optimizer.param_groups,
# "optimizer" : self.optimizer,
# **{"optimizer / param_groups / %d" % i : g
# for i, g in enumerate(self.optimizer.param_groups)},
# "optimizer / state" : self.optimizer.state,
# "optimizer / state_dict" : self.optimizer.state_dict(),
# })
# Initialize main params.
@@ -1028,6 +1083,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
def get_world_model_params(self):
'''** FOR DEBUGGING. **'''
return [ p for m in self.models for p in m.parameters() ]
def get_world_model_grads(self):
'''** FOR DEBUGGING. **'''
return [ p.main_grad for p in self.get_world_model_params() ]
def get_main_params(self):
return [ g["params"][0] for g in self.optimizer.param_groups ]
@@ -1075,20 +1133,25 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
for model_index, model in enumerate(self.models):
for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
world_shards = gbuf_shard["world_all"]
gbuf = model._grad_buffers[dtype]
gbuf_views = []
for shard in world_shards:
gbuf_views.append(gbuf.data[shard.start:shard.end])
gbuf = model._grad_buffers[dtype].data
gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
gbuf_view_items.append((model_index, dtype, gbuf_views))
# pax(0, {
# "world_shards" : world_shards,
# "gbuf_views" : gbuf_views,
# })
# pax(0, {"gbuf_view_items": gbuf_view_items})
return gbuf_view_items
def reduce_grads(self, model):
# >>>
timers = get_timers()
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync word embedding params.
@@ -1101,6 +1164,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
# >>>
raise Exception("hi.")
# <<<
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -1116,6 +1182,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
raise Exception("only 'main_grad' supported for distrib-opt.")
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# +++
@@ -1123,7 +1190,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_embedding_group())
# <<<
timers('backward-embedding-all-reduce').stop()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync T5 position embedding params.
@@ -1133,18 +1200,30 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
gbuf_view_items = self.get_model_grad_buffer_dp_views()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
for model_index, dtype, gbuf_views in gbuf_view_items:
# coalesced /= mpu.get_data_parallel_world_size()
gbuf = self.models[model_index]._grad_buffers[dtype].data
torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
# gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
# gbuf_d
# pax(0, {
# "data_parallel_world_size" : data_parallel_world_size,
# "gbuf" : tp(gbuf),
# })
torch.distributed.reduce_scatter(
gbuf_views[data_parallel_rank],
gbuf_views,
group = data_parallel_group,
)
# pax(0, {"gbuf_view_items": gbuf_view_items})
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
def gather_params(self):
@@ -1161,24 +1240,12 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
group = data_parallel_group,
)
# Each model param now contains its updated values in it's
# Each model param now contains its updated values in its
# '.main_grad' field.
for param in self.param_gbuf_map:
param.detach().copy_(param.main_grad)
# pax(0, {
# "param" : tp(param),
# "main_grad" : tp(param.main_grad),
# # "grad" : tp(param.grad),
# })
# pax(1, {
# "data_parallel_rank" : data_parallel_rank,
# "main params" : self.get_main_params(),
# "model params / world" : self.get_world_model_params(),
# **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
# # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
# # "model params / local" : self.get_local_model_param_views(),
# })
# pax(0, {"gbuf_view_items": gbuf_view_items})
def _collect_main_grad_data_for_unscaling(self):
return [ g.data for g in self.get_main_grads() ]
@@ -1199,51 +1266,29 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# Copy shard data.
main_view = main_param[main_shard.start:main_shard.end]
model_view = model_param.view(-1)[model_shard.start:model_shard.end]
# try:
main_view.detach().copy_(model_view)
# except:
# pax({
# "main_param" : tp(main_param),
# "model_param" : tp(model_param),
# "main_view" : tp(main_view),
# "model_view" : tp(model_view),
# "main_shard" : str(main_shard),
# "model_shard" : str(model_shard),
# })
# pax(0, {
# **{
# "opt_group_shards / %d" % i : s
# for i, s in enumerate(self.opt_group_shards)
# },
# "main_params" : self.get_main_params(),
# })
def _copy_model_grads_to_main_grads(self, ITERATION):
# >>>
model_grads = self.get_local_model_grad_views()
model_has_nan = self.has_nan_debug(model_grads)
if model_has_nan:
pax(1, {
"ITERATION" : ITERATION,
"model grads" : model_grads,
"model_has_nan" : model_has_nan,
"model params / local" : self.get_local_model_param_views(),
# "model params / world" : [ list(self.param_gbuf_map),
# "main grads" : self.get_main_grads(),
})
# <<<
for group_index, group_shard in enumerate(self.opt_group_shards):
for model_param, main_shard in group_shard["param_map"].items():
# Model shard.
model_index, dtype = self.param_gbuf_map[model_param]
model_shard = self.model_gbuf_shards \
[model_index][dtype]["param_map"][model_param]["gbuf_world"]
assert main_shard.size == model_shard.size
# pax(0, {
# "model_param" : tp(model_param),
# "main_shard" : str(main_shard),
# "param shard" : self.model_gbuf_shards \
# [model_index][dtype]["param_map"][model_param],
# })
# Copy from DDP's contiguous buffer to main shard's grad.
model_grad = self.models[model_index]._grad_buffers[dtype].data
main_grad = self.get_main_grad(group_index)
@@ -1269,38 +1314,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# })
# >>>
# pax(1, {
# # "model_gbuf_shards" : self.model_gbuf_shards,
# **{
# "opt_group_shards / %d" % i : s
# for i, s in enumerate(self.opt_group_shards)
# },
# "main_grads" : self.get_main_grads(),
# })
# for group_index, main_grad in enumerate(self.get_main_grads()):
# # is_nan = torch.any(torch.isnan(main_grad)).item()
# if is_nan:
# # opt_group_shard = self.opt_group_shards[group_index]
# # param_views = []
# # for param, shard in opt_group_shard["param_map"].items():
# # ddd
# pax(0, {
# "opt_group_shard" : self.opt_group_shards[group_index],
# "param_map" : [ (str(p.shape), str(d)) for p, d in self.opt_group_shards[group_index]["param_map"].items() ],
# "gbufs" : [ b.data for m in self.models for d, b in m._grad_buffers.items() ],
# "group_index" : group_index,
# "main_param" : tp(self.get_main_param(group_index)),
# "main_grad" : tp(main_grad),
# "is_nan" : is_nan,
# })
main_grads = self.get_main_grads()
main_has_nan = self.has_nan_debug(main_grads)
if main_has_nan:
raise Exception("hi.")
# pax(1, {
# "model grads" : self.get_local_model_grad_views(),
# })
# if ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** fix. **",
# "ITERATION" : ITERATION,
# # "model grads" : self.get_world_model_grads(),
# "main_grads" : self.get_main_grads(),
# })
# <<<
@@ -1340,27 +1360,12 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "main_shard" : str(main_shard),
# })
# pax(0, {
# "model_gbuf_shards" : self.model_gbuf_shards,
# "opt_group_shards" : self.opt_group_shards,
# })
# >>>
for param in self.param_gbuf_map:
# is_nan = torch.any(torch.isnan(param)).item()
is_nan = not torch.all(torch.isfinite(param)).item()
if is_nan:
pax({
"param" : tp(param),
"is_nan" : is_nan,
})
if ITERATION == DEBUG_ITERATION:
pax(0, {
"** branch **" : "** fix. **",
"ITERATION" : ITERATION,
# "main params" : self.get_main_params(),
# "model params / local" : self.get_local_model_param_views(),
"model params" : [p for m in self.models for p in m.parameters()],
"model params" : self.get_world_model_params(),
})
# <<<
......
@@ -432,7 +432,7 @@ def train_step(forward_step_func, data_iterator,
# >>>
# Reduce gradients. (with distributed optimizer option, optimizer
# now responsible for reducing gradients)
optimizer.reduce_gradients(model)
optimizer.reduce_grads(model)
# <<<
# >>>
......
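
The reduce_grads() hunks above replace a plain all-reduce with a reduce-scatter over views of each model's contiguous grad buffer, so every data-parallel rank ends up owning only its shard of the averaged gradients. A rough sketch under assumed names (`reduce_scatter_grad_buffer`, even sharding with `numel % world_size == 0`), not the committed implementation, which derives its shards from `gbuf_shard["world_all"]`:

```python
import torch

def reduce_scatter_grad_buffer(gbuf: torch.Tensor, data_parallel_group) -> torch.Tensor:
    """Sketch: average a contiguous grad buffer across data-parallel ranks,
    leaving each rank with the reduced values of its own shard only."""
    world_size = torch.distributed.get_world_size(group=data_parallel_group)
    rank = torch.distributed.get_rank(group=data_parallel_group)

    # Scale before the collective so the summed result is an average.
    gbuf.mul_(1.0 / world_size)

    # Partition the buffer into one view per data-parallel rank
    # (even sharding assumed for this sketch).
    shard_size = gbuf.numel() // world_size
    views = [gbuf[i * shard_size:(i + 1) * shard_size] for i in range(world_size)]

    # After this call, rank r holds the fully reduced values in views[r].
    torch.distributed.reduce_scatter(views[rank], views, group=data_parallel_group)
    return views[rank]
```

Scaling before the collective keeps the result an average without a second pass over the buffer; the matching train_step change simply routes gradient reduction through `optimizer.reduce_grads(model)` so this path is used when the distributed optimizer is enabled.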