OpenDAS / FastMoE / Commits / 90e394bc
"vscode:/vscode.git/clone" did not exist on "0142f6f35ac14082b6d2416e8ee447f4db3f140b"
Commit 90e394bc, authored Mar 22, 2021 by Jiezhong Qiu

load/save fp32 main weight when fp16 training

Parent: bcfeaf3b
Showing 1 changed file, with 31 additions and 4 deletions.

fmoe/megatron.py  (+31, -4)
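The change touches the checkpoint logic around Megatron's FP16 optimizer, which keeps an fp32 master copy of every fp16 parameter (exposed in its state dict as fp32_from_fp16_params). The sketch below is not FastMoE or Megatron code; it only illustrates, under simplified assumptions, why such a master copy exists: the backward pass produces fp16 gradients, the update is applied to the fp32 weights, and the fp16 weights are refreshed from them.

import torch

# Minimal, hypothetical sketch of fp32 main weights in fp16 training
# (not the Megatron/FastMoE implementation).
fp16_param = torch.randn(4, 4, dtype=torch.float16, requires_grad=True)
fp32_master = fp16_param.detach().float()            # fp32 main weight

loss = (fp16_param.float() ** 2).sum()                # forward pass
loss.backward()                                       # fp16 gradient on fp16_param

lr = 0.1
with torch.no_grad():
    fp32_master -= lr * fp16_param.grad.float()       # update the fp32 master copy
    fp16_param.copy_(fp32_master.half())              # refresh fp16 weight for next step

It is this fp32 master copy (together with the optimizer state) that the commit starts saving and loading for expert parameters, which the data-parallel rank 0 checkpoint does not cover for other ranks' experts.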
@@ -443,7 +443,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     if not args.no_save_optim:
         if optimizer is not None:
             state_dict['optimizer'] = optimizer.state_dict()
-            index = 0
+            param_global_idx = 0
             for param_group in optimizer.optimizer.param_groups:
                 for param in param_group['params']:
                     if not (hasattr(param, 'dp_comm') and \
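The rename from index to param_global_idx reflects what the counter really is: PyTorch's Optimizer.state_dict() re-keys per-parameter state by the parameter's global position across all param groups, so the loop has to count parameters in exactly that order to know which state entries to drop (optimizer.optimizer because Megatron's FP16 optimizer wraps the actual optimizer). A minimal sketch of that keying, using a throwaway two-group SGD optimizer; all names below are illustrative:

import torch

# Optimizer.state_dict() keys per-parameter state by the parameter's global
# index across all param groups (0, 1, 2, ...), not by the tensor object.
w0, w1, w2 = (torch.nn.Parameter(torch.randn(3)) for _ in range(3))
opt = torch.optim.SGD([{"params": [w0, w1]}, {"params": [w2]}],
                      lr=0.1, momentum=0.9)

(w0.sum() + w1.sum() + w2.sum()).backward()
opt.step()                                 # populates momentum buffers

sd = opt.state_dict()
print(sorted(sd["state"].keys()))          # [0, 1, 2] -> global parameter indices
print(sd["param_groups"][1]["params"])     # [2]       -> group 1 refers to index 2

This is why save_checkpoint walks optimizer.optimizer.param_groups with a single running counter and pops non-expert entries by that key, as the next hunk shows.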
@@ -453,12 +453,31 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
                         # since it has been saved by data parallel rank 0
                         if args.fp16:
                             # fp16 optimizer may have empty state due to overflow
-                            state_dict['optimizer']['optimizer']['state'].pop(index, None)
+                            state_dict['optimizer']['optimizer']['state'].pop(param_global_idx, None)
                         else:
-                            state_dict['optimizer']['state'].pop(index)
-                    index += 1
+                            state_dict['optimizer']['state'].pop(param_global_idx)
+                    param_global_idx += 1
+            if args.fp16:
+                state_dict['optimizer']['optimizer'].pop('param_groups')
+                # fp32_from_fp16_params in state_dict is not a copy
+                # but a reference to optimizer.fp32_from_fp16_params,
+                # changing it in state_dict will change
+                # optimizer.fp32_from_fp16_params as well
+                # thus we create an empty fp32_from_fp16_params in state_dict
+                # and only insert expert parameters.
+                fp32_from_fp16_params = \
+                    state_dict['optimizer']['fp32_from_fp16_params']
+                state_dict['optimizer']['fp32_from_fp16_params'] = []
+                for param_group in fp32_from_fp16_params:
+                    param_group_copy = []
+                    for param in param_group:
+                        param_copy = param if hasattr(param, 'dp_comm') \
+                            and param.dp_comm == expert_dp_comm else None
+                        param_group_copy.append(param_copy)
+                    state_dict['optimizer']['fp32_from_fp16_params'].append(param_group_copy)
+            else:
+                state_dict['optimizer'].pop('param_groups')
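The comment in the added block is the key point of the commit: state_dict['optimizer']['fp32_from_fp16_params'] is an alias of the optimizer's own list, so filtering it in place would also wipe the fp32 master weights the optimizer is still training with. The new code therefore rebuilds a fresh outer list that keeps only expert parameters (those tagged with dp_comm == expert_dp_comm) and stores None placeholders everywhere else. A small, self-contained sketch of the aliasing pitfall and the filter, with plain nested lists standing in for the optimizer's parameter groups (all data below is made up):

# Hypothetical stand-in for optimizer.fp32_from_fp16_params.
optimizer_master = [["w_expert", "w_dense"], ["w_other"]]
state_dict = {"fp32_from_fp16_params": optimizer_master}   # same object, not a copy

def is_expert(p):
    # Stand-in for: hasattr(param, 'dp_comm') and param.dp_comm == expert_dp_comm
    return "expert" in p

# Mutating the aliased groups would clobber optimizer_master, so build new lists.
state_dict["fp32_from_fp16_params"] = [
    [p if is_expert(p) else None for p in group]
    for group in optimizer_master
]

print(state_dict["fp32_from_fp16_params"])   # [['w_expert', None], [None]]
print(optimizer_master)                      # untouched

The non-fp16 branch only needs to drop 'param_groups', since a plain fp32 optimizer keeps no separate master-weight copy.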
@@ -512,6 +531,14 @@ def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
                  for kk, vv in optimizer_rank0['state'][k].items()}
         print_rank_last("[merge optimizer] copy {}, \
             before.sum={}, after.sum={}".format(k, str(before), str(after)))
+    if fp16:
+        for group_idx, param_group in enumerate(
+                state_dict_local['optimizer']['fp32_from_fp16_params']):
+            for param_in_group_idx, param in enumerate(param_group):
+                if param is not None:
+                    state_dict_rank0['optimizer']['fp32_from_fp16_params'][
+                        group_idx][param_in_group_idx] = param
+                    print_rank_last("[merge fp32_from_fp16_params] copy parameter ({:d}, {:d})".format(group_idx, param_in_group_idx))
     return state_dict_rank0


 def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
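At load time the added branch makes merge_state_dict undo the save-side filtering: the data-parallel rank 0 checkpoint supplies the shared (non-expert) optimizer state, and every non-None slot in the local rank's fp32_from_fp16_params overwrites the corresponding entry, so each rank ends up with its own expert master weights. A minimal sketch of that merge direction, with toy nested lists in place of the two loaded checkpoints (names are illustrative only):

# Toy stand-ins for the two checkpoints a rank reads when fp16 is enabled:
# rank 0 holds the shared fp32 master weights, the local file holds only this
# rank's expert master weights (every other slot is None).
state_dict_rank0 = {"optimizer": {"fp32_from_fp16_params": [["dense_w", "rank0_expert_w"]]}}
state_dict_local = {"optimizer": {"fp32_from_fp16_params": [[None, "my_expert_w"]]}}

for group_idx, param_group in enumerate(
        state_dict_local["optimizer"]["fp32_from_fp16_params"]):
    for param_idx, param in enumerate(param_group):
        if param is not None:   # expert slot saved by this rank takes precedence
            state_dict_rank0["optimizer"]["fp32_from_fp16_params"][group_idx][param_idx] = param

print(state_dict_rank0["optimizer"]["fp32_from_fp16_params"])
# [['dense_w', 'my_expert_w']] -- shared weights from rank 0, expert weights local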