Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
12b23ae3
Commit
12b23ae3
authored
Mar 17, 2021
by
Jiezhong Qiu
Browse files
remove unnecessary states when dp rank>0
parent
bc655118
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
15 deletions
+1
-15
fmoe/megatron.py
fmoe/megatron.py
+1
-15
No files found.
fmoe/megatron.py
View file @
12b23ae3
...
...
@@ -420,9 +420,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
# Arguments, iteration, and model.
state_dict
=
{}
state_dict
[
'args'
]
=
args
state_dict
[
'checkpoint_version'
]
=
3.0
state_dict
[
'iteration'
]
=
iteration
state_dict
[
'model'
]
=
model
.
state_dict_for_save_checkpoint
(
keep_vars
=
(
mpu
.
get_data_parallel_rank
()
>
0
))
...
...
@@ -456,18 +453,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
# since it has been saved by data parallel rank 0
state_dict
[
'optimizer'
][
'state'
].
pop
(
index
)
index
+=
1
if
lr_scheduler
is
not
None
:
state_dict
[
'lr_scheduler'
]
=
lr_scheduler
.
state_dict
()
# RNG states.
if
not
args
.
no_save_rng
:
state_dict
[
'random_rng_state'
]
=
random
.
getstate
()
state_dict
[
'np_rng_state'
]
=
np
.
random
.
get_state
()
state_dict
[
'torch_rng_state'
]
=
torch
.
get_rng_state
()
state_dict
[
'cuda_rng_state'
]
=
torch
.
cuda
.
get_rng_state
()
state_dict
[
'rng_tracker_states'
]
\
=
mpu
.
get_cuda_rng_tracker
().
get_states
()
state_dict
[
'optimizer'
].
pop
(
'param_groups'
)
# Save.
checkpoint_name
=
get_fmoe_checkpoint_name
(
args
.
save
,
iteration
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment