Commit 7dc8c475 authored by Lawrence McAfee

feb 9 alpha

parent e724785f
@@ -85,6 +85,18 @@ def get_megatron_optimizer(model,
                                    scale_lr_cond,
                                    lr_mult)
    # >>>
    # from lutil import pax
    # pax(0, {
    #     "model" : model,
    #     "param_groups" : param_groups,
    #     "param_groups / 0" : param_groups[0],
    #     "param_groups / 0 / params" : param_groups[0]["params"],
    #     "param_groups / 1" : param_groups[1],
    #     "param_groups / 1 / params" : param_groups[1]["params"],
    # })
    # <<<
    if args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
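The commented-out pax call above dumps the param groups handed to the Adam constructor; pax comes from lutil, which looks like a local debugging helper rather than part of Megatron. A rough plain-print equivalent is sketched below; the wd_mult / lr_mult keys are an assumption based on the usual Megatron param-group layout, so treat it as illustrative only.

# Hypothetical sketch: summarize each param group instead of dumping it
# through pax. The 'wd_mult' / 'lr_mult' keys are assumed, not confirmed
# by this diff.
def dump_param_groups(param_groups):
    for idx, group in enumerate(param_groups):
        params = group["params"]
        numel = sum(p.numel() for p in params)
        print(f"group {idx}: {len(params)} tensors, {numel} elements, "
              f"wd_mult={group.get('wd_mult')}, lr_mult={group.get('lr_mult')}")

Called right after the param groups are built, this shows at a glance how the parameters are split between the two groups inspected in the dump above.
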
@@ -259,14 +259,38 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = main_param
                        # >>>
                        def debug():
                            from lutil import pax, tp
                            pax(0, {
                                "optimizer" : optimizer,
                                # "optimizer / state" : optimizer.state,
                                "optimizer / pg / 0" : optimizer.param_groups[0]["params"],
                                "optimizer / pg / 1" : optimizer.param_groups[1]["params"],
                                "param" : tp(param),
                                "param / hash" : hash(param),
                                "main_param" : tp(main_param),
                                "main_param / hash" : hash(main_param),
                            })
                        # <<<
                        # >>>
                        # debug()
                        # <<<
                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)
                        # >>>
                        # debug()
                        # <<<

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        # >>>
                        from lutil import pax
                        pax(0, {"param": param})
                        # <<<
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param
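The lines instrumented by debug() replace each fp16 parameter in the inner optimizer's param_groups with a freshly created fp32 main copy and then move any pre-existing per-parameter state over to the new key. Stripped of the Megatron context, the pattern is roughly the stand-alone illustration below (using SGD with momentum rather than the actual setup):

import torch

# Stand-alone illustration of the re-keying pattern above: replace a model
# param with its fp32 main copy in the param group, then move the
# optimizer's per-param state from the old key to the new one.
param = torch.nn.Parameter(torch.randn(4).half())
opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
opt.state[param] = {"momentum_buffer": torch.zeros_like(param)}  # pretend prior state

main_param = param.detach().clone().float()
main_param.requires_grad = True

opt.param_groups[0]["params"][0] = main_param    # replace with fp32 copy
if param in opt.state:
    opt.state[main_param] = opt.state.pop(param)  # re-key the state

Re-keying matters because torch optimizers index their state by parameter object, so state created for the fp16 tensor would otherwise be orphaned once the group points at the fp32 copy.
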
@@ -286,6 +310,29 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())

        # >>>
        # from lutil import pax
        # pax(0, {
        #     # "float16_groups / len" : [ len(g) for g in self.float16_groups ],
        #     # "fp32_from_float16_groups / len" :
        #     #     [ len(g) for g in self.fp32_from_float16_groups ],
        #     # "float16_groups / 0" : self.float16_groups[0],
        #     # "float16_groups / 1" : self.float16_groups[1],
        #     # "fp32_from_float16_groups / 0" : self.fp32_from_float16_groups[0],
        #     # "fp32_from_float16_groups / 1" : self.fp32_from_float16_groups[1],
        #     # "fp32_from_float32_groups" : self.fp32_from_fp32_groups,
        #     "optimizer" : self.optimizer,
        #     # "optimizer / sd" : self.optimizer.state_dict(),
        #     # "optimizer / state" : self.optimizer.state_dict()["state"],
        #     # "optimizer / pg" : self.optimizer.state_dict()["param_groups"],
        #     # "optimizer / pg / 0" : self.optimizer.state_dict()["param_groups"][0],
        #     # "optimizer / pg / 1" : self.optimizer.state_dict()["param_groups"][1],
        #     "optimizer -> pg" : optimizer.param_groups,
        #     "optimizer -> pg / 0" : optimizer.param_groups[0]["params"],
        #     "optimizer -> pg / 1" : optimizer.param_groups[1]["params"],
        # })
        # <<<

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
@@ -435,6 +482,16 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
        # Step the optimizer.
        self.optimizer.step()

        # >>>
        # from lutil import pax, tp
        # pax(0, {
        #     "optimizer / state" :
        #         { hash(k):tp(v) for k,v in self.optimizer.state.items() },
        #     "optimizer / state / len" : len(self.optimizer.state),
        #     "optimizer / state / 0" : list(self.optimizer.state.values())[0],
        # })
        # <<<

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
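The commented-out block after self.optimizer.step() dumps the inner optimizer's state keyed by hash(param) (tp appears to be another lutil helper that summarizes a tensor). A plain-PyTorch sketch that produces a similar view of Adam's per-parameter state, shown here only as an illustration:

import torch

# Stand-alone sketch: after a step, Adam keeps per-param state
# ('step', 'exp_avg', 'exp_avg_sq'); print a compact summary keyed by
# hash(param), similar in spirit to the pax dump above.
p = torch.nn.Parameter(torch.randn(8))
opt = torch.optim.Adam([p], lr=1e-3)
p.grad = torch.randn_like(p)
opt.step()

for k, v in opt.state.items():
    summary = {name: (tuple(t.shape), str(t.dtype)) if torch.is_tensor(t) else t
               for name, t in v.items()}
    print(hash(k), summary)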