Commit f8fd34e3 authored by Lawrence McAfee

fixed gather params; now copying all params, not just local dp subset.

parent cbcd5579
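Aside from debug scaffolding, the diff below makes two functional changes: (1) in clip_grads.py, the all-reduce branches are flipped so that the distributed-optimizer path reduces the grad norm and zero count over all ranks, and (2) in the distributed optimizer's gather_params(), updated values are copied back into every parameter tracked by the grad buffers rather than only the local data-parallel subset listed in param_gbuf_map (the bug named in the commit message). Below is a minimal, self-contained sketch of change (2); _ToyModel and gather_params_all are hypothetical stand-ins, while _grad_buffer_param_index_map, param_gbuf_map, and param.main_grad are the names used in the diff.

import torch

class _ToyModel(torch.nn.Module):
    # Stand-in for a Megatron model chunk: each param carries a `.main_grad`
    # tensor that, after the param all-gather, holds the updated values.
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)
        # Mimics `_grad_buffer_param_index_map`: {dtype: {param: buffer index}}.
        self._grad_buffer_param_index_map = {
            torch.float32: {p: i for i, p in enumerate(self.parameters())}
        }
        for p in self.parameters():
            p.main_grad = torch.randn_like(p)  # pretend: freshly gathered values

def gather_params_all(models):
    # After the fix: walk the full param index map of every model chunk so that
    # params outside the local data-parallel shard are updated as well.
    for model in models:
        for dtype, param_map in model._grad_buffer_param_index_map.items():
            for param in param_map:
                param.detach().copy_(param.main_grad)

model = _ToyModel()
gather_params_all([model])
print(torch.allclose(model.fc.weight, model.fc.weight.main_grad))  # True

Per the commit message, iterating only over param_gbuf_map skipped parameters owned by other data-parallel ranks, so those model params were never refreshed after the all-gather.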
......@@ -155,14 +155,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
     # >>>
     from megatron import get_args
     args = get_args()
-    if not args.use_distributed_optimizer:
+    if args.use_distributed_optimizer:
         torch.distributed.all_reduce(total_norm,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=mpu.get_model_parallel_group())
-    # +++
+                                     op=torch.distributed.ReduceOp.SUM)
     else:
         torch.distributed.all_reduce(total_norm,
-                                     op=torch.distributed.ReduceOp.SUM)
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=mpu.get_model_parallel_group())
     # <<<
     total_norm = total_norm.item() ** (1.0 / norm_type)
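The intent of the flipped branch above: with the distributed optimizer each rank holds only a shard of the fp32 main grads, so the partial norm must be summed over the default (world) group; without it, grads are replicated across data-parallel ranks and the sum stays within the model-parallel group. The count_zeros_fp32 hunk below applies the same branch to total_num_zeros. A hedged sketch of the logic (reduce_grad_norm is a hypothetical helper and assumes an initialized torch.distributed process group):

import torch
import torch.distributed as dist

def reduce_grad_norm(local_sq_sum, use_distributed_optimizer,
                     model_parallel_group=None, norm_type=2.0):
    # `local_sq_sum` is this rank's sum of |g|^norm_type over its local grads.
    total = local_sq_sum.clone()
    if use_distributed_optimizer:
        # Main grads are sharded across every rank: reduce over the world group.
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
    else:
        # Grads are replicated across data-parallel ranks: reduce only over the
        # model-parallel group to avoid double counting.
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=model_parallel_group)
    return total.item() ** (1.0 / norm_type)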
......@@ -223,9 +222,12 @@ def count_zeros_fp32(parameters):
# Sum across all model-parallel GPUs.
# >>>
from megatron import get_args
args = get_args()
if args.use_distributed_optimizer:
torch.distributed.all_reduce(total_num_zeros,
op=torch.distributed.ReduceOp.SUM)
# pax({"total_num_zeros": total_num_zeros.item()})
else:
torch.distributed.all_reduce(total_num_zeros,
op=torch.distributed.ReduceOp.SUM,
......
......@@ -32,7 +32,7 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
 # >>>
 from lutil import pax, tp
-DEBUG_ITERATION = 0 # 10
+DEBUG_ITERATION = 2 # 10
 # <<<
......@@ -273,7 +273,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
             return
         for r in range(torch.distributed.get_world_size()):
             if my_rank == r:
-                print(" + %4s; [r%d]; %s, %.12e" % ("fix" if args.use_distributed_optimizer else "main", my_rank, key, value))
+                print(" + br/%s; [r%d, i%d]; %s, %.12e" % ("fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
             torch.distributed.barrier()
         torch.distributed.barrier()
         # if my_rank == 0:
......@@ -282,9 +282,11 @@ class BaseFloat16Optimizer(MegatronOptimizer):
# exit(0)
exit(0)
def _debug_model(self, ITERATION, key, use_param):
# def _debug_model(self, ITERATION, key, use_param):
def debug_model(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
(p.float() if use_param else p.main_grad.float())
(p.main_grad.float() if use_grad else p.float())
for m in self.models for p in m.parameters()
]
# pax(0, {
......@@ -296,65 +298,72 @@ class BaseFloat16Optimizer(MegatronOptimizer):
return self.debug_general(
ITERATION,
"model/%s, %s [count %d]" % (
"param" if use_param else "grad",
"grad" if use_grad else "param",
key,
count,
),
sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
# sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
sum(torch.sum(torch.abs(t)) for t in tensors),
)
def _debug_main(self, ITERATION, key0, key1, f, ff):
count = sum(
p.nelement()
for g in self.optimizer.param_groups
for p in g["params"]
)
return self.debug_general(
ITERATION,
"main/%s, %s [count %d]" % (key1, key0, count),
sum(ff(f(p))
for g in self.optimizer.param_groups
for p in g["params"]).item() / count,
)
# def debug_main_param_mean(self, ITERATION, key):
# def debug_model_param(self, ITERATION, key):
# return self._debug_model(ITERATION, key, True)
# def debug_model_grad(self, ITERATION, key):
# return self._debug_model(ITERATION, key, False)
# def _debug_main(self, ITERATION, key0, key1, f, ff):
# count = sum(
# p.nelement()
# for g in self.optimizer.param_groups
# for p in g["params"]
# )
# return self.debug_general(
# ITERATION,
# "main/%s, %s [count %d]" % (key1, key0, count),
# sum(ff(f(p))
# for g in self.optimizer.param_groups
# for p in g["params"]).item() / count,
# )
# def debug_main_param(self, ITERATION, key):
# return self._debug_main(
# ITERATION,
# key,
# "param mean",
# lambda p : p,
# torch.mean,
# "param", # sum",
# # lambda p : p,
# lambda p : torch.abs(p),
# torch.sum,
# )
# def debug_main_param_sum(self, ITERATION, key):
def debug_model_param(self, ITERATION, key):
return self._debug_model(ITERATION, key, True)
def debug_model_grad(self, ITERATION, key):
return self._debug_model(ITERATION, key, False)
def debug_main_param(self, ITERATION, key):
return self._debug_main(
ITERATION,
key,
"param", # sum",
# lambda p : p,
lambda p : torch.abs(p),
torch.sum,
)
# def debug_main_grad_mean(self, ITERATION, key):
# def debug_main_grad(self, ITERATION, key):
# return self._debug_main(
# ITERATION,
# key,
# "grad mean",
# lambda p : p.grad,
# torch.mean,
# "grad", # sum",
# # lambda p : p.grad,
# lambda p : torch.abs(p.grad),
# torch.sum,
# )
# def debug_main_grad_sum(self, ITERATION, key):
def debug_main_grad(self, ITERATION, key):
return self._debug_main(
# def _debug_main(self, ITERATION, key, use_param):
def debug_main(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
p.grad if use_grad else p
for g in self.optimizer.param_groups
for p in g["params"]
]
tensors = [ t.float() for t in tensors ]
count = sum(t.nelement() for t in tensors)
return self.debug_general(
ITERATION,
key,
"grad", # sum",
# lambda p : p.grad,
lambda p : torch.abs(p.grad),
torch.sum,
"main/%s, %s [count %d]" % (
"grad" if use_grad else "param",
key,
count,
),
sum(torch.sum(torch.abs(t)) for t in tensors),
)
# def debug_main_param(self, ITERATION, key):
# return self._debug_main(ITERATION, key, True)
# def debug_main_grad(self, ITERATION, key):
# return self._debug_main(ITERATION, key, False)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
@torch.no_grad()
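The optimizer.py hunks above fold the old per-quantity helpers (debug_model_param/_grad, debug_main_param/_grad) into two methods, debug_model and debug_main, each taking a use_grad flag and reporting a raw sum of absolute values (the earlier divide-by-count mean is dropped). A condensed, hypothetical rendering of that shape, written as free functions for readability:

import torch

def _debug_report(iteration, key, tensors):
    # Raw sum of |t| plus the element count, tagged with the iteration.
    total = sum(torch.sum(torch.abs(t.float())) for t in tensors)
    count = sum(t.nelement() for t in tensors)
    print("[i%d] %s [count %d]: %.12e" % (iteration, key, count, float(total)))

def debug_model(models, iteration, key, use_grad):
    # Model-side params or their fp32 `.main_grad` buffers.
    tensors = [p.main_grad if use_grad else p
               for m in models for p in m.parameters()]
    _debug_report(iteration, "model/%s, %s" % ("grad" if use_grad else "param", key), tensors)

def debug_main(optimizer, iteration, key, use_grad):
    # Optimizer-side (main) params or their grads.
    tensors = [p.grad if use_grad else p
               for g in optimizer.param_groups for p in g["params"]]
    _debug_report(iteration, "main/%s, %s" % ("grad" if use_grad else "param", key), tensors)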
......@@ -365,6 +374,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
# >>>
# self.debug_model_param(ITERATION, "before copy grad.")
# self.debug_model_grad(ITERATION, "before copy grad.")
# self.debug_main_param(ITERATION, "before copy grad.")
# self.debug_main_grad(ITERATION, "before copy grad.")
# <<<
# Copy gradients from model params to main params.
......@@ -373,10 +384,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
timers('optimizer-copy-to-main-grad').stop()
# >>>
# self.debug_model_param(ITERATION, "after copy grad.")
# self.debug_model_grad(ITERATION, "after copy grad.")
# self.debug_main_param(ITERATION, "after copy grad.")
# self.debug_main_grad(ITERATION, "after copy grad.")
# self.debug_model(ITERATION, "after copy grad.", 0)
# self.debug_main(ITERATION, "after copy grad.", 1)
# <<<
# Do unscale, check for inf, and update grad scaler only for
......@@ -412,12 +421,23 @@ class BaseFloat16Optimizer(MegatronOptimizer):
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
# >>>
# param = self.optimizer.param_groups[0]["params"][0]
# pax(0, {
# "param" : tp(param),
# "grad" : tp(param.grad),
# })
# <<<
# >>>
# self.debug_main(ITERATION, "before step.", 0)
# <<<
# Step the optimizer.
self.optimizer.step()
# >>>
# self.debug_main_param(ITERATION, "after step.")
# self.debug_main_grad(ITERATION, "after step.")
# self.debug_main(ITERATION, "after step.", 0)
# <<<
# Update params from main params.
......@@ -652,7 +672,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         # <<<
         timers('backward-embedding-all-reduce').stop()
-    def gather_params(self):
+    def gather_params(self, ITERATION):
         pass
     def _copy_model_grads_to_main_grads(self, ITERATION):
......@@ -1273,6 +1293,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# timers('backward-params-reduce-scatter').start()
timers('backward-params-all-reduce').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
......@@ -1292,6 +1314,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# buffer_.data, group=mpu.get_data_parallel_group())
# <<<
# >>>
# self.debug_main_param(0, "before reduce scatter")
# self.debug_main_grad(0, "before reduce scatter")
# <<<
for model_index, dtype, gbuf_views in gbuf_view_items:
# coalesced /= mpu.get_data_parallel_world_size()
gbuf = self.models[model_index]._grad_buffers[dtype].data
......@@ -1320,10 +1347,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# gbuf,
# group = data_parallel_group,
# )
# timers('backward-params-reduce-scatter').stop()
timers('backward-params-all-reduce').stop()
# pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
def gather_params(self):
def gather_params(self, ITERATION):
# >>>
timers = get_timers()
# <<<
timers('backward-params-all-gather').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
......@@ -1340,11 +1375,32 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# Each model param now contains its updated values in its
# '.main_grad' field.
for param in self.param_gbuf_map:
param.detach().copy_(param.main_grad)
# for param in self.param_gbuf_map: # ... incomplete param list.
for model in self.models:
for dtype, param_map in model._grad_buffer_param_index_map.items():
for param in param_map:
param.detach().copy_(param.main_grad)
timers('backward-params-all-gather').stop()
# pax(0, {"gbuf_view_items": gbuf_view_items})
# >>>
# self.debug_main(ITERATION, "after/inside gather_params.", 0)
# self.debug_model(ITERATION, "after/inside gather_params.", 0)
# if ITERATION == 2:
# pax(1, {
# "ITERATION" : ITERATION,
# # "gbufs" : [
# # tp(b.data)
# # for m in self.models
# # for b in m._grad_buffers.values()
# # ],
# "param_gbuf_map" : [ str(tuple(p.shape)) for p in self.param_gbuf_map ],
# })
# <<<
def _collect_main_grad_data_for_unscaling(self):
return [ g.data for g in self.get_main_grads() ]
......@@ -1400,24 +1456,29 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# pax(0, {
# "group_index" : group_index,
# "group_shard" : group_shard,
# "param" : tp(param),
# # "param" : tp(param),
# "model_index" : model_index,
# "gbuf_dtype" : str(gbuf_dtype),
# "model_grad_tensor" : tp(model_grad_tensor),
# "main_grad_tensor" : tp(main_grad_tensor),
# "model_grad_view" : tp(model_grad_view),
# "main_grad_view" : tp(main_grad_view),
# "dtype" : str(dtype),
# "model_grad" : tp(model_grad),
# "main_grad" : tp(main_grad),
# "model_view" : tp(model_view),
# "main_view" : tp(main_view),
# "model_shard" : str(model_shard),
# "main_shard" : str(main_shard),
# })
# >>>
# if ITERATION == DEBUG_ITERATION:
# if 1 or ITERATION == DEBUG_ITERATION:
# pax(0, {
# "** branch **" : "** fix. **",
# "ITERATION" : ITERATION,
# # "model grads" : self.get_world_model_grads(),
# "main_grads" : self.get_main_grads(),
# "group shards" : [
# "group %d; %s" % (grp_idx, main_shard)
# for grp_idx, grp_shard in enumerate(self.opt_group_shards)
# for model_param, main_shard in grp_shard["param_map"].items()
# ],
# })
# <<<
......
......@@ -431,6 +431,10 @@ def train_step(forward_step_func, data_iterator,
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()
+    # >>>
+    # optimizer.debug_model(ITERATION, "before reduce grads.", 0)
+    # <<<
     # >>>
     # Reduce gradients. (with distributed optimizer option, optimizer
     # now responsible for reducing gradients)
......@@ -465,7 +469,11 @@ def train_step(forward_step_func, data_iterator,
     # >>>
     # Gather params gradients. (with distributed optimizer option, optimizer
     # now responsible for gathering updated params)
-    optimizer.gather_params()
+    optimizer.gather_params(ITERATION)
     # <<<
+    # >>>
+    # optimizer.debug_model(ITERATION, "after gather params.", 0)
+    # <<<
     # Update learning rate.
......