OpenDAS / Megatron-LM / Commits

Commit 7dc8c475, authored Feb 09, 2022 by Lawrence McAfee
Commit message: feb 9 alpha
Parent: e724785f
Showing 2 changed files with 69 additions and 0 deletions (+69 -0):

    megatron/optimizer/__init__.py    +12 -0
    megatron/optimizer/optimizer.py   +57 -0
megatron/optimizer/__init__.py (view file @ 7dc8c475)

...
@@ -85,6 +85,18 @@ def get_megatron_optimizer(model,
                                     scale_lr_cond,
                                     lr_mult)
 
+    # >>>
+    # from lutil import pax
+    # pax(0, {
+    #     "model" : model,
+    #     "param_groups" : param_groups,
+    #     "param_groups / 0" : param_groups[0],
+    #     "param_groups / 0 / params" : param_groups[0]["params"],
+    #     "param_groups / 1" : param_groups[1],
+    #     "param_groups / 1 / params" : param_groups[1]["params"],
+    # })
+    # <<<
+
     if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
...
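The commented-out pax(...) block above is a hook into lutil, a private debugging helper that is not part of this repository; it would dump the model and the param groups that get_megatron_optimizer builds before constructing Adam. A rough, hypothetical stand-in using only PyTorch prints a similar summary (the helper name summarize_param_groups and the dummy "wd_mult" option are illustrative, not taken from the commit):

import torch

def summarize_param_groups(param_groups):
    # Print a one-line summary per optimizer param group (illustrative only).
    for idx, group in enumerate(param_groups):
        options = {k: v for k, v in group.items() if k != "params"}
        num_tensors = len(group["params"])
        num_elements = sum(p.numel() for p in group["params"])
        print(f"group {idx}: {num_tensors} tensors, "
              f"{num_elements} elements, options={options}")

# Dummy groups shaped like torch.optim param groups, for demonstration.
summarize_param_groups([
    {"params": [torch.nn.Parameter(torch.randn(8, 8))], "wd_mult": 1.0},
    {"params": [torch.nn.Parameter(torch.randn(8))], "wd_mult": 0.0},
])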
megatron/optimizer/optimizer.py (view file @ 7dc8c475)

...
@@ -259,14 +259,38 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                             main_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
+                        # >>>
+                        def debug():
+                            from lutil import pax, tp
+                            pax(0, {
+                                "optimizer" : optimizer,
+                                # "optimizer / state" : optimizer.state,
+                                "optimizer / pg / 0" : optimizer.param_groups[0]["params"],
+                                "optimizer / pg / 1" : optimizer.param_groups[1]["params"],
+                                "param" : tp(param),
+                                "param / hash" : hash(param),
+                                "main_param" : tp(main_param),
+                                "main_param / hash" : hash(main_param),
+                            })
+                        # <<<
+                        # >>>
+                        # debug()
+                        # <<<
                         fp32_from_float16_params_this_group.append(main_param)
                         # Reset existing state dict key to the new main param.
                         if param in self.optimizer.state:
                             self.optimizer.state[main_param] \
                                 = self.optimizer.state.pop(param)
+                        # >>>
+                        # debug()
+                        # <<<
 
                     # fp32 params.
                     elif param.type() == 'torch.cuda.FloatTensor':
+                        # >>>
+                        from lutil import pax
+                        pax(0, {"param": param})
+                        # <<<
                         fp32_params_this_group.append(param)
                         param_group['params'][i] = param
...
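In the hunk above, the debug() helper (inert unless the commented-out debug() calls are re-enabled) dumps the fp16 param, its fp32 main copy, and the optimizer's param groups around the point where each fp16 parameter is replaced by a detached fp32 copy and any preexisting optimizer state is re-keyed via state[main_param] = state.pop(param). A minimal standalone sketch of that re-keying pattern, assuming a plain torch.optim.Adam on CPU and a hand-seeded state entry rather than Megatron's wrapper classes:

import torch

# One fp16 model parameter wrapped by a plain Adam optimizer.
param = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
optimizer = torch.optim.Adam([param], lr=1e-3)

# Pretend per-param state already exists for the fp16 param
# (optimizer.state is a defaultdict, so this creates the entry).
optimizer.state[param]["exp_avg"] = torch.zeros_like(param)

# Build the fp32 main copy and swap it into the param group,
# mirroring param_group['params'][i] = main_param above.
main_param = param.detach().clone().float()
main_param.requires_grad = True
optimizer.param_groups[0]["params"][0] = main_param

# Re-key the existing state from the fp16 param to the fp32 main copy.
if param in optimizer.state:
    optimizer.state[main_param] = optimizer.state.pop(param)

print(main_param in optimizer.state, param in optimizer.state)  # True False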
@@ -286,6 +310,29 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())
 
+        # >>>
+        # from lutil import pax
+        # pax(0, {
+        #     # "float16_groups / len" : [ len(g) for g in self.float16_groups ],
+        #     # "fp32_from_float16_groups / len" :
+        #     #     [ len(g) for g in self.fp32_from_float16_groups ],
+        #     # "float16_groups / 0" : self.float16_groups[0],
+        #     # "float16_groups / 1" : self.float16_groups[1],
+        #     # "fp32_from_float16_groups / 0" : self.fp32_from_float16_groups[0],
+        #     # "fp32_from_float16_groups / 1" : self.fp32_from_float16_groups[1],
+        #     # "fp32_from_float32_groups" : self.fp32_from_fp32_groups,
+        #     "optimizer" : self.optimizer,
+        #     # "optimizer / sd" : self.optimizer.state_dict(),
+        #     # "optimizer / state" : self.optimizer.state_dict()["state"],
+        #     # "optimizer / pg" : self.optimizer.state_dict()["param_groups"],
+        #     # "optimizer / pg / 0" : self.optimizer.state_dict()["param_groups"][0],
+        #     # "optimizer / pg / 1" : self.optimizer.state_dict()["param_groups"][1],
+        #     "optimizer -> pg" : optimizer.param_groups,
+        #     "optimizer -> pg / 0" : optimizer.param_groups[0]["params"],
+        #     "optimizer -> pg / 1" : optimizer.param_groups[1]["params"],
+        # })
+        # <<<
+
 
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
...
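The hunk above adds its commented-out dump right after self.optimizer.load_state_dict(self.optimizer.state_dict()), which the surrounding comment describes as recasting preexisting per-param state tensors: round-tripping through the state dict makes torch.optim cast floating-point state to the dtype and device of the (now fp32) params. A small standalone illustration of that behavior with a plain Adam optimizer, assuming only stock torch.optim semantics rather than anything Megatron-specific:

import torch

p = torch.nn.Parameter(torch.randn(4))  # fp32 param
opt = torch.optim.Adam([p], lr=1e-3)

# One step with a zero gradient just to materialize exp_avg / exp_avg_sq state.
p.grad = torch.zeros_like(p)
opt.step()

# Simulate state whose dtype no longer matches the param (e.g. stale fp16 state).
opt.state[p]["exp_avg"] = opt.state[p]["exp_avg"].half()
print(opt.state[p]["exp_avg"].dtype)  # torch.float16

# Round-tripping through the state dict recasts per-param floating-point
# state tensors to the dtype of the corresponding parameters.
opt.load_state_dict(opt.state_dict())
print(opt.state[p]["exp_avg"].dtype)  # torch.float32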
@@ -435,6 +482,16 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
         # Step the optimizer.
         self.optimizer.step()
 
+        # >>>
+        # from lutil import pax, tp
+        # pax(0, {
+        #     "optimizer / state" :
+        #     { hash(k):tp(v) for k,v in self.optimizer.state.items() },
+        #     "optimizer / state / len" : len(self.optimizer.state),
+        #     "optimizer / state / 0" : list(self.optimizer.state.values())[0],
+        # })
+        # <<<
+
         # Update params from main params.
         timers('optimizer-copy-main-to-model-params').start()
         self._copy_main_params_to_model_params()
...
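The hunk above sits in the step path: the wrapped fp32 optimizer steps, the commented-out pax dump would inspect its per-param state, and _copy_main_params_to_model_params() then writes the updated fp32 main params back into the fp16 model params. A conceptual sketch of that copy-back, assuming plain tensors in place of the class's grouped attributes (the actual helper is not shown in this diff and may use fused multi-tensor copies):

import torch

# Hypothetical stand-ins for the fp16 model params and their fp32 main copies;
# these are not Megatron's real attributes.
model_params = [torch.nn.Parameter(torch.randn(4, dtype=torch.float16)),
                torch.nn.Parameter(torch.randn(2, 3, dtype=torch.float16))]
main_params = [p.detach().clone().float() for p in model_params]

# ... the fp32 optimizer step would have updated main_params here ...

# Copy the updated fp32 main params back into the fp16 model params
# (copy_ performs the fp32 -> fp16 downcast).
with torch.no_grad():
    for model_param, main_param in zip(model_params, main_params):
        model_param.copy_(main_param)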