OpenDAS / FastMoE · Commits

Commit bcfeaf3b, authored Mar 17, 2021 by Jiezhong Qiu
Parent: 12b23ae3

fp16/fp32 optimizer has different data structure

Showing 1 changed file, fmoe/megatron.py, with 18 additions and 8 deletions (+18 -8).
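For context, the commit title refers to the fact that in fp16 training Megatron wraps the real optimizer, so its saved state sits one level deeper under a second 'optimizer' key, exactly as the diff below indexes it. A minimal sketch of the two layouts; the key names mirror the diff, while the helper function itself is purely illustrative:

# Sketch only; key names follow the diff below, the helper is hypothetical.
# fp32 run: state_dict['optimizer'] holds 'state' and 'param_groups' directly.
# fp16 run: the same fields live under state_dict['optimizer']['optimizer'].

def inner_optimizer_dict(state_dict, fp16):
    """Return the dict that actually carries 'state' and 'param_groups'."""
    opt = state_dict['optimizer']
    return opt['optimizer'] if fp16 else opt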
fmoe/megatron.py

@@ -451,9 +451,16 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
                 # this parameter is not an expert parameter
                 # thus there is no need to save its state in current rank
                 # since it has been saved by data parallel rank 0
-                state_dict['optimizer']['state'].pop(index)
+                if args.fp16:
+                    # fp16 optimizer may have empty state due to overflow
+                    state_dict['optimizer']['optimizer']['state'].pop(index, None)
+                else:
+                    state_dict['optimizer']['state'].pop(index)
             index += 1
-    state_dict['optimizer'].pop('param_groups')
+    if args.fp16:
+        state_dict['optimizer']['optimizer'].pop('param_groups')
+    else:
+        state_dict['optimizer'].pop('param_groups')
 
     # Save.
     checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
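To make this hunk concrete, here is a small, self-contained sketch of the pruning it performs; the dictionaries are toy data, not real Megatron checkpoints. pop(index, None) is used on the fp16 path because an fp16 optimizer that has only hit loss-scale overflows may hold no state for a parameter yet.

# Toy illustration of the pruning above (hypothetical data).
fp16 = True
state_dict = {'optimizer': {'optimizer': {
    'state': {0: {'step': 10}},          # parameter 1 has no state yet (overflow)
    'param_groups': [{'lr': 1e-4}],
}}}
non_expert_indices = [0, 1]              # already saved by data parallel rank 0

opt = state_dict['optimizer']['optimizer'] if fp16 else state_dict['optimizer']
for index in non_expert_indices:
    opt['state'].pop(index, None)        # tolerate missing keys on the fp16 path
opt.pop('param_groups')                  # rank 0 keeps the canonical param_groups
print(state_dict)                        # {'optimizer': {'optimizer': {'state': {}}}}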
@@ -476,7 +483,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     torch.distributed.barrier()
 
 
-def merge_state_dict(state_dict_rank0, state_dict_local):
+def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
     """merge two state dicts, one from data parallel rank 0,
     another only contains expert states"""
     from megatron import print_rank_last
@@ -494,12 +501,15 @@ def merge_state_dict(state_dict_rank0, state_dict_local):
             before.sum={:7f}, after.sum={:7f}".format(k, before, after))
 
     merge_model(state_dict_rank0['model'], state_dict_local['model'])
-    for k, v in state_dict_local['optimizer']['state'].items():
+    optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] if fp16 else state_dict_rank0['optimizer']
+    optimizer_local = state_dict_local['optimizer']['optimizer'] if fp16 else state_dict_local['optimizer']
+    for k, v in optimizer_local['state'].items():
         before = {kk: vv.sum().item() \
-            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
-        state_dict_rank0['optimizer']['state'][k] = v
+            for kk, vv in optimizer_rank0['state'][k].items()}
+        optimizer_rank0['state'][k] = v
         after = {kk: vv.sum().item() \
-            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+            for kk, vv in optimizer_rank0['state'][k].items()}
         print_rank_last("[merge optimizer] copy {}, \
             before.sum={}, after.sum={}".format(k, str(before), str(after)))
     return state_dict_rank0
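A self-contained sketch of the merge this hunk performs, again on toy tensors rather than real checkpoints: the locally saved expert states simply overwrite the corresponding entries in the rank-0 state dict, with the fp16 flag deciding how deep the optimizer state is nested.

# Toy illustration of the fp16-aware merge (hypothetical data).
import torch

fp16 = True
state_dict_rank0 = {'optimizer': {'optimizer': {'state': {3: {'exp_avg': torch.zeros(2)}}}}}
state_dict_local = {'optimizer': {'optimizer': {'state': {3: {'exp_avg': torch.ones(2)}}}}}

optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] if fp16 else state_dict_rank0['optimizer']
optimizer_local = state_dict_local['optimizer']['optimizer'] if fp16 else state_dict_local['optimizer']

for k, v in optimizer_local['state'].items():
    optimizer_rank0['state'][k] = v      # expert state replaces the rank-0 copy

print(state_dict_rank0['optimizer']['optimizer']['state'][3]['exp_avg'])  # tensor([1., 1.])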
@@ -581,7 +591,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
     state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
     state_dict_local = load_state_dict(checkpoint_name_local)
 
-    state_dict = merge_state_dict(state_dict_rank0, state_dict_local)
+    state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)
 
     # set checkpoint version
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))