OpenDAS / FastMoE

Commit f5a5d31a, authored Mar 17, 2021 by Jiezhong Qiu
avoid saving duplicated optim state in dp rank > 0
Parent: c844413b
Showing 1 changed file with 16 additions and 3 deletions.

fmoe/megatron.py (+16, -3)
@@ -390,7 +390,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
         ),
         'model_optim_rng.pt')

-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
     """Save a model checkpoint with expert parallel """
     # TODO: update patch
     from megatron import get_args
@@ -398,6 +398,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
     if isinstance(model, DistributedDataParallel):
         model = model.module
@@ -414,7 +415,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     state_dict['model'] = model.state_dict_for_save_checkpoint(
         keep_vars=(mpu.get_data_parallel_rank() > 0))
-    if mpu.get_data_parallel_rank() != 0:
+    if mpu.get_data_parallel_rank() > 0:
         def extract_expert_param(state_dict, expert_dp_comm='none'):
             state_dict_new = state_dict.__class__()
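The body of extract_expert_param falls between these two hunks and is not shown. As a rough, hypothetical sketch (not the committed code) of what such a filter can look like: it walks the model state dict and keeps only entries whose values carry a dp_comm attribute equal to expert_dp_comm. This also explains the keep_vars=(mpu.get_data_parallel_rank() > 0) argument above: with keep_vars=True the state dict holds the Parameter objects themselves rather than detached copies, so a dp_comm tag set on an expert parameter (the same tag the optimizer loop below checks) is still visible.

# Hypothetical sketch only -- the real body of extract_expert_param is not
# part of this diff. Assumes Megatron-style nested state dicts and a dp_comm
# attribute on expert parameters.
def extract_expert_param(state_dict, expert_dp_comm='none'):
    state_dict_new = state_dict.__class__()
    for key, value in state_dict.items():
        if isinstance(value, dict):
            # Megatron state dicts nest sub-module dicts; recurse and keep
            # only non-empty results.
            sub = extract_expert_param(value, expert_dp_comm)
            if len(sub) > 0:
                state_dict_new[key] = sub
        elif hasattr(value, 'dp_comm') and value.dp_comm == expert_dp_comm:
            # keep_vars=True above preserves the Parameter objects, so the
            # dp_comm tag survives into the state dict.
            state_dict_new[key] = value.detach()
        # anything else is a replicated (non-expert) tensor; rank 0 saves it
    return state_dict_new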
@@ -430,12 +431,24 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['model'] = extract_expert_param(
             state_dict['model'],
-            expert_dp_comm='none')
+            expert_dp_comm)

     # Optimizer stuff.
     if not args.no_save_optim:
         if optimizer is not None:
             state_dict['optimizer'] = optimizer.state_dict()
+            if mpu.get_data_parallel_rank() > 0:
+                index = 0
+                for param_group in optimizer.optimizer.param_groups:
+                    for param in param_group['params']:
+                        if not (hasattr(param, 'dp_comm') and \
+                                param.dp_comm == expert_dp_comm):
+                            # this parameter is not an expert parameter
+                            # thus there is no need to save its state in current rank
+                            # since it has been saved by data parallel rank 0
+                            state_dict['optimizer']['state'].pop(index)
+                        index += 1
         if lr_scheduler is not None:
             state_dict['lr_scheduler'] = lr_scheduler.state_dict()
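To see what the new optimizer-state pruning does in isolation, here is a small self-contained sketch (not FastMoE/Megatron code). It uses a plain torch.optim.SGD in place of Megatron's wrapped optimizer, so it reads optimizer.param_groups directly instead of optimizer.optimizer.param_groups, and the dp_comm tag is assigned by hand. PyTorch keys optimizer.state_dict()['state'] by a flat parameter index, which is what the running index in the patch mirrors.

# Standalone sketch of the pruning logic above -- not FastMoE/Megatron code.
import torch

expert_dp_comm = 'none'

expert_w = torch.nn.Parameter(torch.randn(4, 4))   # pretend: expert parameter
shared_w = torch.nn.Parameter(torch.randn(4, 4))   # pretend: replicated parameter
expert_w.dp_comm = expert_dp_comm                  # FastMoE-style tag, set by hand here

optimizer = torch.optim.SGD([expert_w, shared_w], lr=0.1, momentum=0.9)
(expert_w.sum() + shared_w.sum()).backward()
optimizer.step()                                   # creates momentum state for both params

state_dict = {'optimizer': optimizer.state_dict()}

# Same loop shape as the patch: walk params in flat index order and drop the
# state of every parameter that is not tagged as an expert parameter.
index = 0
for param_group in optimizer.param_groups:
    for param in param_group['params']:
        if not (hasattr(param, 'dp_comm') and param.dp_comm == expert_dp_comm):
            # replicated parameter: data-parallel rank 0 already saves its state
            state_dict['optimizer']['state'].pop(index)
        index += 1

print(sorted(state_dict['optimizer']['state'].keys()))  # -> [0], only the expert entry left

Because replicated parameters are identical on every data-parallel rank, dropping their optimizer state on ranks > 0 leaves each rank's checkpoint with only the expert state unique to it, which is what avoids duplicating the state that rank 0 already writes.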