OpenDAS / FastMoE

Commit a372eb66, authored Mar 22, 2021 by Jiezhong Qiu
fix pylint
parent d8124b80
Showing 1 changed file with 10 additions and 4 deletions.
fmoe/megatron.py (+10, -4)

@@ -238,11 +238,12 @@ def get_fmoe_checkpoint_name(checkpoints_path, iteration,
         ),
         'model_optim_rng.pt'
     )
-def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint with expert parallel """
     # TODO: update patch
     from megatron import get_args
     from megatron import mpu
+    expert_dp_comm = 'none'
 
     if mpu.get_data_parallel_rank() == 0:
         # at dp rank 0, we still follows the native load_checkpoint by megatron
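Note on this hunk: the extra expert_dp_comm keyword is dropped from the patched save_checkpoint and set as a local variable instead, so the function keeps the same four parameters as Megatron's native save_checkpoint, which is likely what the pylint complaint named in the commit message was about. A minimal, runnable sketch of that drop-in-replacement pattern follows; native_save_checkpoint and patched_save_checkpoint are hypothetical stand-in names, not FastMoE or Megatron APIs.

def native_save_checkpoint(iteration, model, optimizer, lr_scheduler):
    # Stand-in for the upstream function being patched over.
    print("native save at iteration", iteration)

def patched_save_checkpoint(iteration, model, optimizer, lr_scheduler):
    # Same four parameters as the native version; the extra option lives in
    # the body rather than in the signature.
    expert_dp_comm = 'none'
    print("expert-aware save at iteration", iteration, "comm =", expert_dp_comm)

# Any call site written against the native signature accepts either function.
for save_fn in (native_save_checkpoint, patched_save_checkpoint):
    save_fn(0, model=None, optimizer=None, lr_scheduler=None)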
@@ -362,8 +363,10 @@ def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
                 before.sum={:7f}, after.sum={:7f}".format(k, before, after))
     merge_model(state_dict_rank0['model'], state_dict_local['model'])
-    optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] if fp16 else state_dict_rank0['optimizer']
-    optimizer_local = state_dict_local['optimizer']['optimizer'] if fp16 else state_dict_local['optimizer']
+    optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] \
+        if fp16 else state_dict_rank0['optimizer']
+    optimizer_local = state_dict_local['optimizer']['optimizer'] \
+        if fp16 else state_dict_local['optimizer']
     for k, v in optimizer_local['state'].items():
         before = {kk: vv.sum().item() \
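Note on this hunk: the two long conditional assignments are split with backslash continuations, presumably to satisfy pylint's line-too-long check (C0301). A small self-contained sketch of the same fix is below, using a toy state_dict_local whose shape is only assumed; a parenthesized conditional expression is an equivalent alternative that avoids the backslash.

# Toy stand-in for the real Megatron optimizer state dict (hypothetical shape).
state_dict_local = {'optimizer': {'optimizer': {'state': {}}}}
fp16 = True

# Backslash continuation, as used in the commit:
optimizer_local = state_dict_local['optimizer']['optimizer'] \
    if fp16 else state_dict_local['optimizer']

# Parenthesized form with identical behaviour, no backslash needed:
optimizer_local = (state_dict_local['optimizer']['optimizer']
                   if fp16 else state_dict_local['optimizer'])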
@@ -389,7 +392,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
     from megatron import get_args
     from megatron import mpu
     from megatron import print_rank_last
-    from megatron.checkpointing import get_checkpoint_tracker_filename, set_checkpoint_version, check_checkpoint_args, update_num_microbatches
+    from megatron.checkpointing import get_checkpoint_tracker_filename
+    from megatron.checkpointing import set_checkpoint_version
+    from megatron.checkpointing import check_checkpoint_args
+    from megatron.checkpointing import update_num_microbatches
     if mpu.get_data_parallel_rank() == 0:
         # at dp rank 0, we still follow the native load_checkpoint by megatron
         from megatron.checkpointing import load_checkpoint as load_checkpoint_native
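Note on this hunk: a single import line listing four names from megatron.checkpointing is split into one import per name, again keeping each line within the length limit. A short sketch of the two equivalent styles follows, shown with standard-library names so it runs anywhere; os.path is only a stand-in for megatron.checkpointing and the path is made up.

# One name per line, the style chosen in the commit:
from os.path import join
from os.path import dirname

# Parenthesized multi-name import, also line-length friendly:
from os.path import (basename, exists)

path = '/tmp/checkpoints/iter_0000100/model_optim_rng.pt'
print(join(dirname(path), basename(path)), exists(path))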