Commit fe3cfd86 authored by Lawrence McAfee

working: copy grads.

parent 5e29905f
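
The change below switches the shard-slicing code from indexing a parameter directly with a flat element range (which actually slices along dim 0 of a multi-dimensional tensor) to flattening with `.view(-1)` first, and adds asserts that every shard is non-empty. A minimal standalone sketch of the difference follows; the tensor shape, range values, and variable names are made up for illustration and are not taken from the repository.

```python
# Illustrative only: why the shard slice flattens with .view(-1) first.
import torch

param = torch.randn(4, 8, dtype=torch.float16)  # hypothetical 2-D model param (32 elements)
start, end = 5, 21                              # hypothetical shard range in flat-element offsets

# Old pattern: indexing the 2-D tensor slices rows (dim 0), so a flat range is
# misinterpreted and can even come back empty (start >= param.shape[0] here).
rows = param.detach()[start:end]                # shape (0, 8), i.e. an empty shard

# New pattern: flatten, then take the flat element range for this rank's shard.
shard_model_param = param.detach().view(-1)[start:end]   # 16 elements
shard_main_param = shard_model_param.clone().float()     # fp32 "main" copy of the shard

assert shard_model_param.nelement() == end - start
assert shard_main_param.nelement() > 0
print(rows.shape, shard_model_param.shape, shard_main_param.dtype)
```

Slicing the un-flattened tensor can silently produce an empty or wrongly sized shard, which is the failure mode the new `assert param_range.size > 0` and `assert shard_main_param.nelement() > 0` checks guard against.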
@@ -277,13 +277,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 gbuf_range = model_gbuf_ranges[model_index][dtype]
                 param_range = gbuf_range["param_map"][model_param]["param"]
+
+                # >>>
+                assert param_range.size > 0
+                # <<<

                 # fp16, bf16 params.
                 if model_param.type() in ['torch.cuda.HalfTensor',
                                           'torch.cuda.BFloat16Tensor']:

                     # Clone model -> main.
-                    shard_model_param = \
-                        model_param.detach()[param_range.start:param_range.end]
+                    shard_model_param = model_param.detach().view(-1) \
+                        [param_range.start:param_range.end]
                     shard_main_param = shard_model_param.clone().float()
                     mpu.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
@@ -293,6 +297,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     shard_model_param.shared = model_param.shared
                     shard_main_param.shared = model_param.shared

+                    # >>>
+                    assert shard_main_param.nelement() > 0, \
+                        "param_range = %s." % param_range
+                    # <<<
+
                     # Add to group.
                     full_float16_params_this_group.append(model_param)
                     shard_float16_params_this_group.append(shard_model_param)
@@ -300,8 +309,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 # fp32 params.
                 elif param.type() == 'torch.cuda.FloatTensor':
-                    shard_model_param = \
-                        model_param[param_range.start:param_range.end]
+                    shard_model_param = model_param.view(-1) \
+                        [param_range.start:param_range.end]
                     full_fp32_params_this_group.append(model_param)
                     shard_fp32_params_this_group.append(shard_model_param)
@@ -661,35 +670,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     # <<<

     # >>>
-    # def _copy_model_grads_to_main_grads(self):
-    #     for group_index, group_range in enumerate(self.opt_group_ranges):
-    #         for model_param, main_range in group_range["param_map"].items():
-
-    #             # Model range.
-    #             # model_index, dtype = self.param_gbuf_map[model_param]
-    #             # model_range = self.model_gbuf_ranges \
-    #             #     [model_index][dtype]["param_map"][model_param]["gbuf_world"]
-    #             model_range = self.get_model_param_range_map(model_param)["gbuf_world"]
-
-    #             assert main_range.size == model_range.size
-
-    #             # Copy from DDP's contiguous buffer to main shard's grad.
-    #             model_grad = self.models[model_index]._grad_buffers[dtype].data
-    #             main_grad = self.get_main_grad(group_index)
-
-    #             # Copy sub-range within tensor.
-    #             model_view = model_grad[model_range.start:model_range.end]
-    #             main_view = main_grad[main_range.start:main_range.end]
-    #             main_view.detach().copy_(model_view)
-
-    # def _copy_model_grads_to_main_grads(self):
-    #     super()._copy_model_grads_to_main_grads()
-    #     raise Exception("check main param '.grad'.")
-    #     for group in self.optimizer.param_groups:
-    #         for param in group["params"]:
-    #             param.grad =
     def _copy_model_grads_to_main_grads(self):

         # >>>
@@ -708,38 +688,22 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     param_range_map = self.get_model_param_range_map(full_model_param)
                     param_range = param_range_map["param"]
+                    assert param_range.size == shard_main_param.nelement()

                     full_model_grad = full_model_param.main_grad
-                    shard_model_grad = \
-                        full_model_grad[param_range.start:param_range.end]
+                    shard_model_grad = full_model_grad.view(-1) \
+                        [param_range.start:param_range.end]
                     shard_main_param.grad = shard_model_grad.float()

-                    # >>>
-                    if full_model_param.nelement() != shard_main_param.nelement():
-                        pax(0, {
-                            "param_range_map" : param_range_map,
-                            "param_range" : param_range,
-                            "full_model_param" : tp(full_model_param),
-                            "full_model_grad" : tp(full_model_grad),
-                            "shard_model_grad" : tp(shard_model_grad),
-                            "shard_main_grad" : tp(shard_main_param.grad),
-                            "shard_main_param" : tp(shard_main_param),
-                        })
-                    # <<<
-
-        # print_seq("float16 groups: %d [%s], %d [%s]." % (
-        #     len(self.full_float16_groups),
-        #     # ",".join(str(len(g)) for g in self.full_float16_groups),
-        #     ",".join(str(tuple(p.shape)) for gs in self.full_float16_groups for g in gs for p in g),
-        #     len(self.shard_fp32_from_float16_groups),
-        #     ",".join(str(len(g)) for g in self.shard_fp32_from_float16_groups),
-        # ))
-        gs = self.full_float16_groups
-        pax(0, {
-            **{"gs / %d" % i : len(g) for i, g in enumerate(gs)},
-        })
+        # print_seq([ "%s / %d, [%d] %s" % (
+        #     k, i, len(g), ", ".join(str(p.nelement()) for p in g),
+        # ) for k, gs in [
+        #     ("model", self.full_float16_groups),
+        #     ("main", self.shard_fp32_from_float16_groups),
+        # ] for i, g in enumerate(gs)])
+
         copy_group_grads(self.full_float16_groups,
                          self.shard_fp32_from_float16_groups)
-        print_seq("hi.")
         copy_group_grads(self.full_fp32_groups,
                          self.shard_fp32_groups)
@@ -750,7 +714,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         #     for p in g["params"]
         # ])
         # <<<
-
         # <<<

     # >>>
@@ -778,17 +741,61 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     # def _copy_main_params_to_model_params(self):
     #     super()._copy_main_params_to_model_params()
     #     raise Exception("check main param '.grad'.")
+
+    # def _copy_main_params_to_model_params(self):
+    #     raise Exception("hi.")
+
+    #     # This only needs to be done for the float16 group.
+    #     for model_group, main_group in zip(self.float16_groups,
+    #                                        self.fp32_from_float16_groups):
+    #         for model_param, main_param in zip(model_group, main_group):
+    #             model_param.main_grad.detach().copy_(main_param)
+
+    #     # For fp32 grads, we need to reset the grads to main grad.
+    #     for group in self.fp32_groups:
+    #         for param in group:
+    #             param.main_grad.detach().copy_(param)
     def _copy_main_params_to_model_params(self):
-        raise Exception("hi.")

-        # This only needs to be done for the float16 group.
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
-            for model_param, main_param in zip(model_group, main_group):
-                model_param.main_grad.detach().copy_(main_param)
+        # >>>
+        # print_seq([
+        #     "grad = %s." % tp(p.grad)
+        #     for g in self.optimizer.param_groups
+        #     for p in g["params"]
+        # ])
+        # <<<

-        # For fp32 grads, we need to reset the grads to main grad.
-        for group in self.fp32_groups:
-            for param in group:
-                param.main_grad.detach().copy_(param)
+        def copy_group_params(shard_main_groups, full_model_groups):
+            for shard_main_group, full_model_group in zip(shard_main_groups,
+                                                          full_model_groups):
+                for shard_main_param, full_model_param in zip(shard_main_group,
+                                                              full_model_group):
+
+                    param_range_map = self.get_model_param_range_map(full_model_param)
+                    param_range = param_range_map["param"]
+                    assert param_range.size == shard_main_param.nelement()
+
+                    full_model_grad = full_model_param.main_grad
+                    shard_model_grad = full_model_grad.view(-1) \
+                        [param_range.start:param_range.end]
+                    shard_main_param.grad = shard_model_grad.float()
+
+        # print_seq([ "%s / %d, [%d] %s" % (
+        #     k, i, len(g), ", ".join(str(p.nelement()) for p in g),
+        # ) for k, gs in [
+        #     ("model", self.full_float16_groups),
+        #     ("main", self.shard_fp32_from_float16_groups),
+        # ] for i, g in enumerate(gs)])
+
+        copy_group_params(self.shard_fp32_from_float16_groups,
+                          self.full_float16_groups)
+        copy_group_params(self.shard_fp32_groups,
+                          self.full_fp32_groups)
+
+        # >>>
+        # print_seq([
+        #     "grad = %s." % tp(p.grad)
+        #     for g in self.optimizer.param_groups
+        #     for p in g["params"]
+        # ])
+        # <<<
     # <<<
@@ -327,22 +327,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._copy_model_params_to_main_params()

-    # >>>
-    # def zero_grad(self, set_to_none=True):
-    #     """We only need to zero the model related parameters, i.e.,
-    #     float16_groups & fp32_from_fp32_groups. We additionally zero
-    #     fp32_from_float16_groups as a memory optimization to reduce
-    #     fragmentation; in the case of set_to_none==True, the space
-    #     used by this field can be safely deallocated at this point."""
-    #     for group in self.float16_groups:
-    #         _zero_grad_group_helper(group, set_to_none)
-    #     for group in self.fp32_from_float16_groups:
-    #         _zero_grad_group_helper(group, set_to_none)
-    #     for group in self.fp32_from_fp32_groups:
-    #         _zero_grad_group_helper(group, set_to_none)
-    # <<<
-
     def _unscale_main_grads_and_check_for_nan(self):

         # Collect main grads.
@@ -431,7 +415,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()
-        print_seq("hi.")

         # Do unscale, check for inf, and update grad scaler only for
         # the case that grad scaler is provided.
         if self.grad_scaler:
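
For reference, the grad-copy pattern this commit lands in `copy_group_grads` (and mirrors in the new `copy_group_params` helper) reduces to the sketch below; the parameter, its `main_grad` attribute, and the shard range are hypothetical stand-ins for the optimizer's real per-rank state, not code from the repository.

```python
# Hypothetical stand-in for one (full_model_param, shard_main_param) pair.
import torch

full_model_param = torch.randn(4, 8, dtype=torch.float16)
full_model_param.main_grad = torch.randn_like(full_model_param)  # stands in for the DDP grad-buffer view

start, end = 8, 24                                               # flat element range owned by this rank
shard_main_param = full_model_param.detach().view(-1)[start:end].clone().float()

# Copy grads: slice the flattened main_grad and upcast to fp32 for the shard.
full_model_grad = full_model_param.main_grad
shard_model_grad = full_model_grad.view(-1)[start:end]
shard_main_param.grad = shard_model_grad.float()

assert shard_main_param.grad.nelement() == end - start
```

Pointing `shard_main_param.grad` at a freshly upcast slice appears to replace the approach in the deleted, commented-out version, which copied each range into a preallocated main-grad buffer via `main_view.detach().copy_(model_view)`.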