OpenDAS / FastMoE · Commits

Commit bcfeaf3b, authored Mar 17, 2021 by Jiezhong Qiu

fp16/fp32 optimizer has different data structure

parent 12b23ae3
Showing 1 changed file with 18 additions and 8 deletions

fmoe/megatron.py  (+18, -8)
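For context on the fix: when Megatron-LM runs with --fp16, the optimizer handed to checkpointing is a wrapper whose state_dict() nests the real optimizer's state one level deeper under an inner 'optimizer' key, while the fp32 path stores it flat. A rough sketch of the two shapes this commit branches on (illustrative keys only, not copied from Megatron's source):

    # fp32: the optimizer's state_dict sits directly under 'optimizer'
    state_dict_fp32 = {
        'model': ...,
        'optimizer': {'state': {0: {...}}, 'param_groups': [...]},
    }

    # fp16: the wrapper saves the inner optimizer's state_dict one level
    # deeper, alongside wrapper-specific entries (e.g. loss-scale bookkeeping)
    state_dict_fp16 = {
        'model': ...,
        'optimizer': {'optimizer': {'state': {0: {...}}, 'param_groups': [...]}},
    }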
@@ -451,8 +451,15 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
                 # this parameter is not an expert parameter
                 # thus there is no need to save its state in current rank
                 # since it has been saved by data parallel rank 0
-                state_dict['optimizer']['state'].pop(index)
+                if args.fp16:
+                    # fp16 optimizer may have empty state due to overflow
+                    state_dict['optimizer']['optimizer']['state'].pop(index, None)
+                else:
+                    state_dict['optimizer']['state'].pop(index)
             index += 1
-    state_dict['optimizer'].pop('param_groups')
+    if args.fp16:
+        state_dict['optimizer']['optimizer'].pop('param_groups')
+    else:
+        state_dict['optimizer'].pop('param_groups')
     # Save.
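Note the tolerant dict.pop(index, None) on the fp16 branch: as the in-line comment says, the fp16 optimizer may have empty state after an overflow (the loss scaler skips the update, so a parameter may have no optimizer state yet and the key can be missing). A tiny illustration of the difference, using a toy dict rather than FastMoE code:

    state = {}             # fp16 optimizer state can be empty after skipped steps
    state.pop(0, None)     # tolerant: returns None when the key is absent
    try:
        state.pop(0)       # the fp32 branch's plain pop() would raise here
    except KeyError:
        print('plain pop() raises on a missing key')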
@@ -476,7 +483,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     torch.distributed.barrier()


-def merge_state_dict(state_dict_rank0, state_dict_local):
+def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
     """merge two state dicts, one from data parallel rank 0,
     another only contains expert states"""
     from megatron import print_rank_last
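The new fp16 parameter tells merge_state_dict which nesting level to read. The selection that follows could equally be factored into a small helper; a hypothetical sketch (this function does not exist in the repo):

    def unwrap_optimizer(state_dict, fp16):
        # Megatron's fp16 wrapper stores the wrapped optimizer's state_dict
        # under an extra 'optimizer' key; the fp32 path stores it flat.
        return state_dict['optimizer']['optimizer'] if fp16 \
            else state_dict['optimizer']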
@@ -494,12 +501,15 @@ def merge_state_dict(state_dict_rank0, state_dict_local):
             before.sum={:7f}, after.sum={:7f}".format(k, before, after))

     merge_model(state_dict_rank0['model'], state_dict_local['model'])
-    for k, v in state_dict_local['optimizer']['state'].items():
+    optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] \
+        if fp16 else state_dict_rank0['optimizer']
+    optimizer_local = state_dict_local['optimizer']['optimizer'] \
+        if fp16 else state_dict_local['optimizer']
+    for k, v in optimizer_local['state'].items():
         before = {kk: vv.sum().item() \
-            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
-        state_dict_rank0['optimizer']['state'][k] = v
+            for kk, vv in optimizer_rank0['state'][k].items()}
+        optimizer_rank0['state'][k] = v
         after = {kk: vv.sum().item() \
-            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+            for kk, vv in optimizer_rank0['state'][k].items()}
         print_rank_last("[merge optimizer] copy {}, \
             before.sum={}, after.sum={}".format(k, str(before), str(after)))
     return state_dict_rank0
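To check the merge against both layouts, here is a self-contained toy run; merge_expert_states is a hypothetical stand-in for the loop above with the logging stripped out:

    import torch

    def merge_expert_states(sd_rank0, sd_local, fp16):
        # the same selection the commit introduces: unwrap one level for fp16
        opt_rank0 = sd_rank0['optimizer']['optimizer'] if fp16 else sd_rank0['optimizer']
        opt_local = sd_local['optimizer']['optimizer'] if fp16 else sd_local['optimizer']
        for k, v in opt_local['state'].items():
            opt_rank0['state'][k] = v  # overwrite with this rank's expert state
        return sd_rank0

    inner = lambda t: {'state': {1: {'exp_avg': t}}, 'param_groups': []}

    # fp32 layout: flat
    merged = merge_expert_states({'optimizer': inner(torch.zeros(2))},
                                 {'optimizer': inner(torch.ones(2))}, fp16=False)
    print(merged['optimizer']['state'][1]['exp_avg'])               # tensor([1., 1.])

    # fp16 layout: one extra nesting level
    merged = merge_expert_states({'optimizer': {'optimizer': inner(torch.zeros(2))}},
                                 {'optimizer': {'optimizer': inner(torch.ones(2))}}, fp16=True)
    print(merged['optimizer']['optimizer']['state'][1]['exp_avg'])  # tensor([1., 1.])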
@@ -581,7 +591,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
         state_dict_local = load_state_dict(checkpoint_name_local)

-        state_dict = merge_state_dict(state_dict_rank0, state_dict_local)
+        state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)

     # set checkpoint version
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))