Commit f8fd34e3 authored by Lawrence McAfee

fixed gather params; now copying all params, not just local dp subset.

parent cbcd5579
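In brief, the fix described by the commit message changes Float16DistributedOptimizer.gather_params so that updated values are copied into every model parameter tracked by each model's grad-buffer index map, rather than only the parameters present in self.param_gbuf_map (the local data-parallel subset). A minimal sketch of the new copy pattern, extracted from the diff below; the surrounding classes and the _grad_buffer_param_index_map / .main_grad attributes are Megatron-style assumptions, not shown here:

def copy_updated_params(models):
    # After the optimizer step, each model param's updated value has been
    # written into its '.main_grad' buffer view; copy it back into the param.
    for model in models:
        for dtype, param_map in model._grad_buffer_param_index_map.items():
            for param in param_map:
                param.detach().copy_(param.main_grad)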
@@ -155,14 +155,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
     # >>>
     from megatron import get_args
     args = get_args()
-    if not args.use_distributed_optimizer:
+    if args.use_distributed_optimizer:
         torch.distributed.all_reduce(total_norm,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=mpu.get_model_parallel_group())
+                                     op=torch.distributed.ReduceOp.SUM)
+    # +++
     else:
         torch.distributed.all_reduce(total_norm,
-                                     op=torch.distributed.ReduceOp.SUM)
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=mpu.get_model_parallel_group())
     # <<<
     total_norm = total_norm.item() ** (1.0 / norm_type)
@@ -223,9 +222,12 @@ def count_zeros_fp32(parameters):
     # Sum across all model-parallel GPUs.
     # >>>
+    from megatron import get_args
+    args = get_args()
     if args.use_distributed_optimizer:
         torch.distributed.all_reduce(total_num_zeros,
                                      op=torch.distributed.ReduceOp.SUM)
+        # pax({"total_num_zeros": total_num_zeros.item()})
     else:
         torch.distributed.all_reduce(total_num_zeros,
                                      op=torch.distributed.ReduceOp.SUM,
...
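The clip_grads.py hunks above swap the branches so that, with --use-distributed-optimizer, the squared-norm and zero-count partial sums are all-reduced over the default (world) process group, while the non-distributed path keeps the model-parallel-group reduction, presumably because main grads are sharded across data-parallel ranks under the distributed optimizer but replicated otherwise. A rough standalone sketch of that selection logic (not taken verbatim from the file; the rationale in the comments is my reading, not stated in the commit):

import torch
from megatron import get_args
from megatron import mpu

def reduce_total_norm(total_norm):
    args = get_args()
    if args.use_distributed_optimizer:
        # Grad shards are split across data-parallel ranks, so sum the
        # partial norms over all ranks (default group).
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM)
    else:
        # Grads are replicated across data-parallel ranks; reducing over the
        # model-parallel group counts each gradient exactly once.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=mpu.get_model_parallel_group())
    return total_norm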
@@ -32,7 +32,7 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
 # >>>
 from lutil import pax, tp
-DEBUG_ITERATION = 0 # 10
+DEBUG_ITERATION = 2 # 10
 # <<<
@@ -273,7 +273,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
             return
         for r in range(torch.distributed.get_world_size()):
             if my_rank == r:
-                print(" + %4s; [r%d]; %s, %.12e" % ("fix" if args.use_distributed_optimizer else "main", my_rank, key, value))
+                print(" + br/%s; [r%d, i%d]; %s, %.12e" % ("fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
             torch.distributed.barrier()
         torch.distributed.barrier()
         # if my_rank == 0:
@@ -282,9 +282,11 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         # exit(0)
         exit(0)

-    def _debug_model(self, ITERATION, key, use_param):
+    # def _debug_model(self, ITERATION, key, use_param):
+    def debug_model(self, ITERATION, key, use_grad):
+        use_grad = bool(use_grad)
         tensors = [
-            (p.float() if use_param else p.main_grad.float())
+            (p.main_grad.float() if use_grad else p.float())
             for m in self.models for p in m.parameters()
         ]
         # pax(0, {
@@ -296,65 +298,72 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         return self.debug_general(
             ITERATION,
             "model/%s, %s [count %d]" % (
-                "param" if use_param else "grad",
+                "grad" if use_grad else "param",
                 key,
                 count,
             ),
-            sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
+            # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
+            sum(torch.sum(torch.abs(t)) for t in tensors),
         )

-    def _debug_main(self, ITERATION, key0, key1, f, ff):
-        count = sum(
-            p.nelement()
-            for g in self.optimizer.param_groups
-            for p in g["params"]
-        )
-        return self.debug_general(
-            ITERATION,
-            "main/%s, %s [count %d]" % (key1, key0, count),
-            sum(ff(f(p))
-                for g in self.optimizer.param_groups
-                for p in g["params"]).item() / count,
-        )
-    # def debug_main_param_mean(self, ITERATION, key):
-    #     return self._debug_main(
-    #         ITERATION,
-    #         key,
-    #         "param mean",
-    #         lambda p : p,
-    #         torch.mean,
-    #     )
-    # def debug_main_param_sum(self, ITERATION, key):
-    def debug_model_param(self, ITERATION, key):
-        return self._debug_model(ITERATION, key, True)
-    def debug_model_grad(self, ITERATION, key):
-        return self._debug_model(ITERATION, key, False)
-    def debug_main_param(self, ITERATION, key):
-        return self._debug_main(
-            ITERATION,
-            key,
-            "param", # sum",
-            # lambda p : p,
-            lambda p : torch.abs(p),
-            torch.sum,
-        )
-    # def debug_main_grad_mean(self, ITERATION, key):
-    #     return self._debug_main(
-    #         ITERATION,
-    #         key,
-    #         "grad mean",
-    #         lambda p : p.grad,
-    #         torch.mean,
-    #     )
-    # def debug_main_grad_sum(self, ITERATION, key):
-    def debug_main_grad(self, ITERATION, key):
-        return self._debug_main(
-            ITERATION,
-            key,
-            "grad", # sum",
-            # lambda p : p.grad,
-            lambda p : torch.abs(p.grad),
-            torch.sum,
-        )
+    # def debug_model_param(self, ITERATION, key):
+    #     return self._debug_model(ITERATION, key, True)
+    # def debug_model_grad(self, ITERATION, key):
+    #     return self._debug_model(ITERATION, key, False)

+    # def _debug_main(self, ITERATION, key0, key1, f, ff):
+    #     count = sum(
+    #         p.nelement()
+    #         for g in self.optimizer.param_groups
+    #         for p in g["params"]
+    #     )
+    #     return self.debug_general(
+    #         ITERATION,
+    #         "main/%s, %s [count %d]" % (key1, key0, count),
+    #         sum(ff(f(p))
+    #             for g in self.optimizer.param_groups
+    #             for p in g["params"]).item() / count,
+    #     )
+    # def debug_main_param(self, ITERATION, key):
+    #     return self._debug_main(
+    #         ITERATION,
+    #         key,
+    #         "param", # sum",
+    #         # lambda p : p,
+    #         lambda p : torch.abs(p),
+    #         torch.sum,
+    #     )
+    # def debug_main_grad(self, ITERATION, key):
+    #     return self._debug_main(
+    #         ITERATION,
+    #         key,
+    #         "grad", # sum",
+    #         # lambda p : p.grad,
+    #         lambda p : torch.abs(p.grad),
+    #         torch.sum,
+    #     )
+    # def _debug_main(self, ITERATION, key, use_param):
+    def debug_main(self, ITERATION, key, use_grad):
+        use_grad = bool(use_grad)
+        tensors = [
+            p.grad if use_grad else p
+            for g in self.optimizer.param_groups
+            for p in g["params"]
+        ]
+        tensors = [ t.float() for t in tensors ]
+        count = sum(t.nelement() for t in tensors)
+        return self.debug_general(
+            ITERATION,
+            "main/%s, %s [count %d]" % (
+                "grad" if use_grad else "param",
+                key,
+                count,
+            ),
+            sum(torch.sum(torch.abs(t)) for t in tensors),
+        )
+    # def debug_main_param(self, ITERATION, key):
+    #     return self._debug_main(ITERATION, key, True)
+    # def debug_main_grad(self, ITERATION, key):
+    #     return self._debug_main(ITERATION, key, False)
     # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

     @torch.no_grad()
@@ -365,6 +374,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         # >>>
         # self.debug_model_param(ITERATION, "before copy grad.")
         # self.debug_model_grad(ITERATION, "before copy grad.")
+        # self.debug_main_param(ITERATION, "before copy grad.")
+        # self.debug_main_grad(ITERATION, "before copy grad.")
         # <<<

         # Copy gradients from model params to main params.
@@ -373,10 +384,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         timers('optimizer-copy-to-main-grad').stop()

         # >>>
-        # self.debug_model_param(ITERATION, "after copy grad.")
-        # self.debug_model_grad(ITERATION, "after copy grad.")
-        # self.debug_main_param(ITERATION, "after copy grad.")
-        # self.debug_main_grad(ITERATION, "after copy grad.")
+        # self.debug_model(ITERATION, "after copy grad.", 0)
+        # self.debug_main(ITERATION, "after copy grad.", 1)
         # <<<

         # Do unscale, check for inf, and update grad scaler only for
@@ -412,12 +421,23 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         num_zeros_in_grad = self.count_zeros() if \
             self.log_num_zeros_in_grad else None

+        # >>>
+        # param = self.optimizer.param_groups[0]["params"][0]
+        # pax(0, {
+        #     "param" : tp(param),
+        #     "grad" : tp(param.grad),
+        # })
+        # <<<
+
+        # >>>
+        # self.debug_main(ITERATION, "before step.", 0)
+        # <<<
+
         # Step the optimizer.
         self.optimizer.step()

         # >>>
-        # self.debug_main_param(ITERATION, "after step.")
-        # self.debug_main_grad(ITERATION, "after step.")
+        # self.debug_main(ITERATION, "after step.", 0)
         # <<<

         # Update params from main params.
@@ -652,7 +672,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         # <<<
         timers('backward-embedding-all-reduce').stop()

-    def gather_params(self):
+    def gather_params(self, ITERATION):
         pass

     def _copy_model_grads_to_main_grads(self, ITERATION):
@@ -1273,6 +1293,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Reduce-scatter.
+        # timers('backward-params-reduce-scatter').start()
+        timers('backward-params-all-reduce').start()
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_group = mpu.get_data_parallel_group()
@@ -1292,6 +1314,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #         buffer_.data, group=mpu.get_data_parallel_group())
         # <<<

+        # >>>
+        # self.debug_main_param(0, "before reduce scatter")
+        # self.debug_main_grad(0, "before reduce scatter")
+        # <<<
+
         for model_index, dtype, gbuf_views in gbuf_view_items:
             # coalesced /= mpu.get_data_parallel_world_size()
             gbuf = self.models[model_index]._grad_buffers[dtype].data
@@ -1320,10 +1347,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             #     gbuf,
             #     group = data_parallel_group,
             # )
+        # timers('backward-params-reduce-scatter').stop()
+        timers('backward-params-all-reduce').stop()

         # pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})

-    def gather_params(self):
+    def gather_params(self, ITERATION):
+
+        # >>>
+        timers = get_timers()
+        # <<<
+
+        timers('backward-params-all-gather').start()
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_group = mpu.get_data_parallel_group()
@@ -1340,11 +1375,32 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # Each model param now contains its updated values in its
         # '.main_grad' field.
-        for param in self.param_gbuf_map:
-            param.detach().copy_(param.main_grad)
+        # for param in self.param_gbuf_map: # ... incomplete param list.
+        for model in self.models:
+            for dtype, param_map in model._grad_buffer_param_index_map.items():
+                for param in param_map:
+                    param.detach().copy_(param.main_grad)
+
+        timers('backward-params-all-gather').stop()

         # pax(0, {"gbuf_view_items": gbuf_view_items})

+        # >>>
+        # self.debug_main(ITERATION, "after/inside gather_params.", 0)
+        # self.debug_model(ITERATION, "after/inside gather_params.", 0)
+        # if ITERATION == 2:
+        #     pax(1, {
+        #         "ITERATION" : ITERATION,
+        #         # "gbufs" : [
+        #         #     tp(b.data)
+        #         #     for m in self.models
+        #         #     for b in m._grad_buffers.values()
+        #         # ],
+        #         "param_gbuf_map" : [ str(tuple(p.shape)) for p in self.param_gbuf_map ],
+        #     })
+        # <<<

     def _collect_main_grad_data_for_unscaling(self):
         return [ g.data for g in self.get_main_grads() ]
@@ -1400,24 +1456,29 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 # pax(0, {
                 #     "group_index" : group_index,
                 #     "group_shard" : group_shard,
-                #     "param" : tp(param),
+                #     # "param" : tp(param),
                 #     "model_index" : model_index,
-                #     "gbuf_dtype" : str(gbuf_dtype),
-                #     "model_grad_tensor" : tp(model_grad_tensor),
-                #     "main_grad_tensor" : tp(main_grad_tensor),
-                #     "model_grad_view" : tp(model_grad_view),
-                #     "main_grad_view" : tp(main_grad_view),
+                #     "dtype" : str(dtype),
+                #     "model_grad" : tp(model_grad),
+                #     "main_grad" : tp(main_grad),
+                #     "model_view" : tp(model_view),
+                #     "main_view" : tp(main_view),
                 #     "model_shard" : str(model_shard),
                 #     "main_shard" : str(main_shard),
                 # })

         # >>>
-        # if ITERATION == DEBUG_ITERATION:
+        # if 1 or ITERATION == DEBUG_ITERATION:
         #     pax(0, {
         #         "** branch **" : "** fix. **",
         #         "ITERATION" : ITERATION,
         #         # "model grads" : self.get_world_model_grads(),
         #         "main_grads" : self.get_main_grads(),
+        #         "group shards" : [
+        #             "group %d; %s" % (grp_idx, main_shard)
+        #             for grp_idx, grp_shard in enumerate(self.opt_group_shards)
+        #             for model_param, main_shard in grp_shard["param_map"].items()
+        #         ],
        #     })
         # <<<
...
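The bulk of the optimizer.py hunk above also folds the separate param/grad debug helpers into debug_model(ITERATION, key, use_grad) and debug_main(ITERATION, key, use_grad), each reducing the tensors of interest to a single absolute-sum scalar that debug_general prints per rank, so the "fix" (distributed optimizer) branch can be compared against "main" iteration by iteration. A condensed sketch of that probe pattern outside the class, purely illustrative; "report" stands in for the class's debug_general printer:

import torch

def abs_sum_probe(tensors, label, report):
    # Collapse a list of tensors to one scalar (sum of absolute values) so two
    # branches or ranks can be compared with a single printed number.
    tensors = [t.float() for t in tensors]
    count = sum(t.nelement() for t in tensors)
    report("%s [count %d]" % (label, count),
           sum(torch.sum(torch.abs(t)) for t in tensors))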
@@ -431,6 +431,10 @@ def train_step(forward_step_func, data_iterator,
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()

+    # >>>
+    # optimizer.debug_model(ITERATION, "before reduce grads.", 0)
+    # <<<
+
     # >>>
     # Reduce gradients. (with distributed optimizer option, optimizer
     # now responsible for reducing gradients)
@@ -465,7 +469,11 @@ def train_step(forward_step_func, data_iterator,
     # >>>
     # Gather params gradients. (with distributed optimizer option, optimizer
     # now responsible for gathering updated params)
-    optimizer.gather_params()
+    optimizer.gather_params(ITERATION)
     # <<<

+    # >>>
+    # optimizer.debug_model(ITERATION, "after gather params.", 0)
+    # <<<
+
     # Update learning rate.
...
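For context, the train_step changes above keep the distributed-optimizer flow in which gradient reduction and parameter gathering are owned by the optimizer rather than the DDP wrapper, as the in-diff comments state. A schematic of the ordering this commit assumes; gather_params(ITERATION) is taken from the diff, while reduce_model_grads is a hypothetical name for the reduce step, which is not shown in this hunk:

def train_step_sketch(optimizer, ITERATION):
    # Backward pass has already accumulated local grads into each param's
    # grad buffer ('.main_grad').
    optimizer.reduce_model_grads()      # hypothetical name for the cross-rank reduce
    optimizer.step()                    # update (the local shard of) main params
    optimizer.gather_params(ITERATION)  # copy updated values back into all model params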