"torchvision/vscode:/vscode.git/clone" did not exist on "af225a8ad529a2a337bca2faef6752199f637d38"
Commit cbcd5579 authored by Lawrence McAfee

fixed 'count zeros' for distrib opt

parent 67e23459
@@ -222,9 +222,16 @@ def count_zeros_fp32(parameters):
             total_num_zeros = num_zeros + total_num_zeros
     # Sum across all model-parallel GPUs.
-    torch.distributed.all_reduce(total_num_zeros,
-                                 op=torch.distributed.ReduceOp.SUM,
-                                 group=mpu.get_model_parallel_group())
+    # >>>
+    if args.use_distributed_optimizer:
+        torch.distributed.all_reduce(total_num_zeros,
+                                     op=torch.distributed.ReduceOp.SUM)
+    else:
+        torch.distributed.all_reduce(total_num_zeros,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=mpu.get_model_parallel_group())
+    # <<<
     total_num_zeros = total_num_zeros.item()
     return total_num_zeros
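Note on the hunk above: with `args.use_distributed_optimizer`, each rank only holds its shard of the main parameters and gradients, so the partial zero counts have to be summed over the default (world) process group rather than just the model-parallel group; in the replicated layout, restricting the reduction to the model-parallel group is what avoids counting each data-parallel copy more than once. A minimal sketch of that distinction (the `count_zeros` helper and its arguments here are illustrative, not the Megatron API):

```python
# Illustrative sketch only (helper name and arguments are hypothetical, not the
# Megatron API): counting zeros when each rank holds either a full replica of
# the gradients or only its shard of them.
import torch
import torch.distributed as dist

def count_zeros(local_tensors, sharded, model_parallel_group=None):
    total = torch.zeros(1, dtype=torch.float, device=local_tensors[0].device)
    for t in local_tensors:
        total += t.numel() - torch.count_nonzero(t)
    if sharded:
        # Distributed-optimizer layout: every rank owns a disjoint slice, so the
        # partial counts must be summed over the entire world (default group).
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
    else:
        # Replicated layout: summing only within the model-parallel group avoids
        # counting each data-parallel replica more than once.
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=model_parallel_group)
    return total.item()
```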
@@ -273,7 +273,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
             return
         for r in range(torch.distributed.get_world_size()):
             if my_rank == r:
-                print(" + %4s; [r%d]; %s, %.12e." % ("fix" if args.use_distributed_optimizer else "main", my_rank, key, value))
+                print(" + %4s; [r%d]; %s, %.12e" % ("fix" if args.use_distributed_optimizer else "main", my_rank, key, value))
             torch.distributed.barrier()
         torch.distributed.barrier()
         # if my_rank == 0:
@@ -282,6 +282,26 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         # exit(0)
         exit(0)
+    def _debug_model(self, ITERATION, key, use_param):
+        tensors = [
+            (p.float() if use_param else p.main_grad.float())
+            for m in self.models for p in m.parameters()
+        ]
+        # pax(0, {
+        #     "params" : params,
+        #     "params / abs" : [ torch.abs(p) for p in params ],
+        #     "params / abs / sum" : [ torch.sum(torch.abs(p)) for p in params ],
+        # })
+        count = sum(t.nelement() for t in tensors)
+        return self.debug_general(
+            ITERATION,
+            "model/%s, %s [count %d]" % (
+                "param" if use_param else "grad",
+                key,
+                count,
+            ),
+            sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
+        )
     def _debug_main(self, ITERATION, key0, key1, f, ff):
         count = sum(
             p.nelement()
@@ -303,11 +323,16 @@ class BaseFloat16Optimizer(MegatronOptimizer):
     # lambda p : p,
     # torch.mean,
     # )
-    def debug_main_param_sum(self, ITERATION, key):
+    # def debug_main_param_sum(self, ITERATION, key):
+    def debug_model_param(self, ITERATION, key):
+        return self._debug_model(ITERATION, key, True)
+    def debug_model_grad(self, ITERATION, key):
+        return self._debug_model(ITERATION, key, False)
+    def debug_main_param(self, ITERATION, key):
         return self._debug_main(
             ITERATION,
             key,
-            "param sum",
+            "param", # sum",
             # lambda p : p,
             lambda p : torch.abs(p),
             torch.sum,
@@ -320,11 +345,12 @@ class BaseFloat16Optimizer(MegatronOptimizer):
     # lambda p : p.grad,
     # torch.mean,
     # )
-    def debug_main_grad_sum(self, ITERATION, key):
+    # def debug_main_grad_sum(self, ITERATION, key):
+    def debug_main_grad(self, ITERATION, key):
         return self._debug_main(
             ITERATION,
             key,
-            "grad sum",
+            "grad", # sum",
             # lambda p : p.grad,
             lambda p : torch.abs(p.grad),
             torch.sum,
@@ -336,14 +362,21 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         timers = get_timers()
+        # >>>
+        # self.debug_model_param(ITERATION, "before copy grad.")
+        # self.debug_model_grad(ITERATION, "before copy grad.")
+        # <<<
         # Copy gradients from model params to main params.
         timers('optimizer-copy-to-main-grad').start()
         self._copy_model_grads_to_main_grads(ITERATION)
         timers('optimizer-copy-to-main-grad').stop()
         # >>>
-        # self.debug_main_param_sum(ITERATION)
-        # self.debug_main_grad_sum(ITERATION)
+        # self.debug_model_param(ITERATION, "after copy grad.")
+        # self.debug_model_grad(ITERATION, "after copy grad.")
+        # self.debug_main_param(ITERATION, "after copy grad.")
+        # self.debug_main_grad(ITERATION, "after copy grad.")
         # <<<
         # Do unscale, check for inf, and update grad scaler only for
@@ -383,8 +416,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         self.optimizer.step()
         # >>>
-        # self.debug_main_param_sum(ITERATION, "after step.")
-        self.debug_main_grad_sum(ITERATION, "after step.")
+        # self.debug_main_param(ITERATION, "after step.")
+        # self.debug_main_grad(ITERATION, "after step.")
         # <<<
         # Update params from main params.
@@ -393,8 +426,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         timers('optimizer-copy-main-to-model-params').stop()
         # >>>
-        self.debug_main_param_sum(ITERATION, "after copy param.")
-        self.debug_main_grad_sum(ITERATION, "after copy param.")
+        # self.debug_main_param(ITERATION, "after copy param.")
+        # self.debug_main_grad(ITERATION, "after copy param.")
         # <<<
         # Successful update.
@@ -1247,22 +1280,46 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
         # pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
+        # pax(0, {"gbufs": [
+        #     g.data
+        #     for m in self.models
+        #     for g in m._grad_buffers.values()
+        # ]})
         # >>>
         # buffer_.data /= mpu.get_data_parallel_world_size()
         # torch.distributed.all_reduce(
         #     buffer_.data, group=mpu.get_data_parallel_group())
         # <<<
         for model_index, dtype, gbuf_views in gbuf_view_items:
             # coalesced /= mpu.get_data_parallel_world_size()
             gbuf = self.models[model_index]._grad_buffers[dtype].data
-            torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
-            # gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
-            # gbuf_d
+            # >>>
+            # ~~ distributed.py ~~
+            # gbuf /= data_parallel_world_size
+            # torch.distributed.all_reduce(gbuf, group=data_parallel_group)
+            # pax(0, {
+            #     "data_parallel_world_size" : data_parallel_world_size,
+            #     "gbuf" : tp(gbuf),
+            # })
+            # <<<
+            # torch.mul(gbuf.data, 1. / data_parallel_world_size, out = gbuf.data)
+            # gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
+            gbuf /= data_parallel_world_size
+            # if 1:
             torch.distributed.reduce_scatter(
                 gbuf_views[data_parallel_rank],
                 gbuf_views,
                 group = data_parallel_group,
             )
+            # else:
+            # torch.distributed.all_reduce(
+            #     gbuf,
+            #     group = data_parallel_group,
+            # )
         # pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
@@ -52,7 +52,9 @@ from megatron.utils import calc_params_l2_norm
 from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory
+# >>>
+from lutil import pax
+# <<<
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -435,6 +437,21 @@ def train_step(forward_step_func, data_iterator,
     optimizer.reduce_grads(model)
     # <<<
+    # >>>
+    # r = mpu.get_data_parallel_rank()
+    # w = mpu.get_data_parallel_world_size()
+    # gbufs = []
+    # for m in model:
+    #     for g in m._grad_buffers.values():
+    #         t = g.data
+    #         n = t.nelement()
+    #         shard = int(n / w)
+    #         start_index = r * shard
+    #         end_index = min(n, start_index + shard)
+    #         gbufs.append(t[start_index:end_index])
+    # pax(1, {"gbufs": gbufs})
+    # <<<
     # >>>
     # from lutil import pax
     # pax(0, {"optimizer": optimizer})