Commit cbcd5579 authored by Lawrence McAfee

fixed 'count zeros' for distrib opt

parent 67e23459
@@ -222,9 +222,16 @@ def count_zeros_fp32(parameters):
             total_num_zeros = num_zeros + total_num_zeros

     # Sum across all model-parallel GPUs.
-    torch.distributed.all_reduce(total_num_zeros,
-                                 op=torch.distributed.ReduceOp.SUM,
-                                 group=mpu.get_model_parallel_group())
+    # >>>
+    if args.use_distributed_optimizer:
+        torch.distributed.all_reduce(total_num_zeros,
+                                     op=torch.distributed.ReduceOp.SUM)
+    else:
+        torch.distributed.all_reduce(total_num_zeros,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=mpu.get_model_parallel_group())
+    # <<<
     total_num_zeros = total_num_zeros.item()

     return total_num_zeros
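For readers skimming the hunk above: with --use-distributed-optimizer each data-parallel rank only holds a shard of the gradients, so the per-rank zero counts are summed over the default (world) process group instead of just the model-parallel group. The sketch below illustrates that reduction pattern only; it is not the Megatron implementation, and the single-process gloo setup and the count_zeros helper are illustrative assumptions.

# Minimal sketch (not the Megatron implementation) of the reduction pattern
# above, run as a single-process "distributed" job with the gloo backend.
import os
import torch
import torch.distributed as dist

def count_zeros(tensors, use_distributed_optimizer, model_parallel_group=None):
    # Count zeros locally, then sum the counts across ranks.
    total = torch.zeros(1)
    for t in tensors:
        total += t.numel() - torch.count_nonzero(t)
    if use_distributed_optimizer:
        # Grads are sharded across data-parallel ranks, so sum over all ranks.
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
    else:
        # Grads are replicated across data-parallel ranks; sum only over the
        # model-parallel group to avoid over-counting.
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=model_parallel_group)
    return total.item()

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    grads = [torch.tensor([0.0, 1.0, 0.0]), torch.zeros(4)]
    print(count_zeros(grads, use_distributed_optimizer=True))  # 6.0
    dist.destroy_process_group()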
@@ -273,7 +273,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
             return
         for r in range(torch.distributed.get_world_size()):
             if my_rank == r:
-                print(" + %4s; [r%d]; %s, %.12e." % ("fix" if args.use_distributed_optimizer else "main", my_rank, key, value))
+                print(" + %4s; [r%d]; %s, %.12e" % ("fix" if args.use_distributed_optimizer else "main", my_rank, key, value))
             torch.distributed.barrier()
         torch.distributed.barrier()
         # if my_rank == 0:
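The loop in this hunk serializes output across ranks: exactly one rank prints per pass, and the barrier keeps the others in step so the lines do not interleave. A minimal sketch of the same pattern, assuming a stand-alone script with a single-process gloo group (not Megatron's debug_general):

# Minimal sketch of the rank-ordered printing pattern above (not Megatron code).
import os
import torch.distributed as dist

def print_in_rank_order(message):
    my_rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Each pass lets exactly one rank print; the barrier keeps the rest waiting.
    for r in range(world_size):
        if my_rank == r:
            print("[r%d] %s" % (my_rank, message))
        dist.barrier()

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print_in_rank_order("hello")
    dist.destroy_process_group()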
@@ -282,6 +282,26 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         # exit(0)
         exit(0)

+    def _debug_model(self, ITERATION, key, use_param):
+        tensors = [
+            (p.float() if use_param else p.main_grad.float())
+            for m in self.models for p in m.parameters()
+        ]
+        # pax(0, {
+        #     "params" : params,
+        #     "params / abs" : [ torch.abs(p) for p in params ],
+        #     "params / abs / sum" : [ torch.sum(torch.abs(p)) for p in params ],
+        # })
+        count = sum(t.nelement() for t in tensors)
+        return self.debug_general(
+            ITERATION,
+            "model/%s, %s [count %d]" % (
+                "param" if use_param else "grad",
+                key,
+                count,
+            ),
+            sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
+        )
+
     def _debug_main(self, ITERATION, key0, key1, f, ff):
         count = sum(
             p.nelement()
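The new _debug_model helper above reports the mean absolute value over all model parameters (or their main_grad buffers). A plain-torch sketch of that statistic, using an assumed toy nn.Linear model in place of Megatron's model list:

# Standalone sketch of the statistic _debug_model reports: the mean absolute
# value across a list of tensors. Plain torch, no Megatron dependencies.
import torch

def mean_abs(tensors):
    count = sum(t.nelement() for t in tensors)
    total = sum(torch.sum(torch.abs(t.float())) for t in tensors)
    return total.item() / count

if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    params = [p.detach() for p in model.parameters()]
    print("model/param: %.12e [count %d]" %
          (mean_abs(params), sum(p.nelement() for p in params)))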
@@ -303,11 +303,16 @@ class BaseFloat16Optimizer(MegatronOptimizer):
     #         lambda p : p,
     #         torch.mean,
     #     )
-    def debug_main_param_sum(self, ITERATION, key):
+    # def debug_main_param_sum(self, ITERATION, key):
+    def debug_model_param(self, ITERATION, key):
+        return self._debug_model(ITERATION, key, True)
+    def debug_model_grad(self, ITERATION, key):
+        return self._debug_model(ITERATION, key, False)
+    def debug_main_param(self, ITERATION, key):
         return self._debug_main(
             ITERATION,
             key,
-            "param sum",
+            "param", # sum",
             # lambda p : p,
             lambda p : torch.abs(p),
             torch.sum,
@@ -320,11 +345,12 @@ class BaseFloat16Optimizer(MegatronOptimizer):
     #         lambda p : p.grad,
     #         torch.mean,
     #     )
-    def debug_main_grad_sum(self, ITERATION, key):
+    # def debug_main_grad_sum(self, ITERATION, key):
+    def debug_main_grad(self, ITERATION, key):
         return self._debug_main(
             ITERATION,
             key,
-            "grad sum",
+            "grad", # sum",
             # lambda p : p.grad,
             lambda p : torch.abs(p.grad),
             torch.sum,
@@ -336,14 +362,21 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         timers = get_timers()

+        # >>>
+        # self.debug_model_param(ITERATION, "before copy grad.")
+        # self.debug_model_grad(ITERATION, "before copy grad.")
+        # <<<
+
         # Copy gradients from model params to main params.
         timers('optimizer-copy-to-main-grad').start()
         self._copy_model_grads_to_main_grads(ITERATION)
         timers('optimizer-copy-to-main-grad').stop()

         # >>>
-        # self.debug_main_param_sum(ITERATION)
-        # self.debug_main_grad_sum(ITERATION)
+        # self.debug_model_param(ITERATION, "after copy grad.")
+        # self.debug_model_grad(ITERATION, "after copy grad.")
+        # self.debug_main_param(ITERATION, "after copy grad.")
+        # self.debug_main_grad(ITERATION, "after copy grad.")
         # <<<

         # Do unscale, check for inf, and update grad scaler only for
@@ -383,8 +416,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         self.optimizer.step()

         # >>>
-        # self.debug_main_param_sum(ITERATION, "after step.")
-        self.debug_main_grad_sum(ITERATION, "after step.")
+        # self.debug_main_param(ITERATION, "after step.")
+        # self.debug_main_grad(ITERATION, "after step.")
         # <<<

         # Update params from main params.
@@ -393,8 +426,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         timers('optimizer-copy-main-to-model-params').stop()

         # >>>
-        self.debug_main_param_sum(ITERATION, "after copy param.")
-        self.debug_main_grad_sum(ITERATION, "after copy param.")
+        # self.debug_main_param(ITERATION, "after copy param.")
+        # self.debug_main_grad(ITERATION, "after copy param.")
         # <<<

         # Successful update.
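The last three hunks move the debug calls around the mixed-precision step. For orientation, a schematic and hedged sketch of that order, with a single made-up fp16 parameter standing in for the Megatron optimizer's param groups: copy the model (fp16) grad into the fp32 main grad, step on the main param, then copy the updated main param back into the model param.

# Schematic sketch (assumptions, not the Megatron classes) of the step order
# the hunks above instrument: model grad -> main grad, optimizer.step() on the
# fp32 main param, then main param -> model param.
import torch

model_param = torch.nn.Parameter(torch.randn(8, dtype=torch.float16))
main_param = model_param.detach().float().clone().requires_grad_(True)
optimizer = torch.optim.SGD([main_param], lr=0.1)

# Pretend a backward pass produced an fp16 gradient on the model param.
model_param.grad = torch.randn(8, dtype=torch.float16)

# 1. Copy gradients from the model (fp16) param to the main (fp32) param.
main_param.grad = model_param.grad.float()
# 2. (Grad unscaling and the inf/nan check would happen here in the real optimizer.)
# 3. Step on the fp32 main param.
optimizer.step()
# 4. Copy the updated main param back into the fp16 model param.
with torch.no_grad():
    model_param.copy_(main_param)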
@@ -1247,22 +1280,46 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         gbuf_view_items = self.get_model_grad_buffer_dp_views()

         # pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
+        # pax(0, {"gbufs": [
+        #     g.data
+        #     for m in self.models
+        #     for g in m._grad_buffers.values()
+        # ]})
+
+        # >>>
+        # buffer_.data /= mpu.get_data_parallel_world_size()
+        # torch.distributed.all_reduce(
+        #     buffer_.data, group=mpu.get_data_parallel_group())
+        # <<<

         for model_index, dtype, gbuf_views in gbuf_view_items:

             # coalesced /= mpu.get_data_parallel_world_size()
             gbuf = self.models[model_index]._grad_buffers[dtype].data
-            torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
-            # gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
-            # gbuf_d
+
+            # >>>
+            # ~~ distributed.py ~~
+            # gbuf /= data_parallel_world_size
+            # torch.distributed.all_reduce(gbuf, group=data_parallel_group)
             # pax(0, {
+            #     "data_parallel_world_size" : data_parallel_world_size,
             #     "gbuf" : tp(gbuf),
             # })
+            # <<<
+
+            # torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
+            # gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
+            gbuf /= data_parallel_world_size
+
+            # if 1:
             torch.distributed.reduce_scatter(
                 gbuf_views[data_parallel_rank],
                 gbuf_views,
                 group = data_parallel_group,
             )
+            # else:
+            #     torch.distributed.all_reduce(
+            #         gbuf,
+            #         group = data_parallel_group,
+            #     )

         # pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
...
@@ -52,7 +52,9 @@ from megatron.utils import calc_params_l2_norm
 from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory

+# >>>
+from lutil import pax
+# <<<

 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -435,6 +437,21 @@ def train_step(forward_step_func, data_iterator,
     optimizer.reduce_grads(model)
     # <<<

+    # >>>
+    # r = mpu.get_data_parallel_rank()
+    # w = mpu.get_data_parallel_world_size()
+    # gbufs = []
+    # for m in model:
+    #     for g in m._grad_buffers.values():
+    #         t = g.data
+    #         n = t.nelement()
+    #         shard = int(n / w)
+    #         start_index = r * shard
+    #         end_index = min(n, start_index + shard)
+    #         gbufs.append(t[start_index:end_index])
+    # pax(1, {"gbufs": gbufs})
+    # <<<
+
     # >>>
     # from lutil import pax
     # pax(0, {"optimizer": optimizer})
...