OpenDAS / FastMoE · Commits

Commit 49a4678c, authored Mar 16, 2021 by Jiezhong Qiu

save experts separately in each data parallel rank

Parent: 89d6c794
Changes: 1 changed file with 105 additions and 0 deletions.
fmoe/megatron.py (view file @ 49a4678c)
@@ -3,8 +3,11 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
"""
import os
import math
import numpy as np
import random
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -361,3 +364,105 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
        Keep consitency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)


def get_checkpoint_name(checkpoints_path, iteration, release=False):
    """A unified checkpoint name."""
    from megatron import mpu
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}_dp_rank_{:04d}'.format(
                                mpu.get_tensor_model_parallel_rank(),
                                mpu.get_data_parallel_rank()),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank(),
                            mpu.get_data_parallel_rank()),
                        'model_optim_rng.pt')


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallel """
    from megatron import get_args
    from megatron import mpu
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    if torch.distributed.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    data_parallel_rank = mpu.get_data_parallel_rank()

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict['args'] = args
    state_dict['checkpoint_version'] = 3.0
    state_dict['iteration'] = iteration
    keep_vars = False if mpu.get_data_parallel_rank() == 0 else True
    state_dict['model'] = model.state_dict_for_save_checkpoint(
        keep_vars=keep_vars)
    if mpu.get_data_parallel_rank() != 0:
        def extract_expert_param(state_dict, expert_dp_comm='none'):
            state_dict_new = state_dict.__class__()
            for k, v in state_dict.items():
                # megatron uses both dict and OrderedDict in its state_dict
                if isinstance(v, OrderedDict) or isinstance(v, dict):
                    v_new = extract_expert_param(v, expert_dp_comm)
                    if len(v_new):
                        state_dict_new[k] = v_new
                elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
                    state_dict_new[k] = v.detach()
            return state_dict_new

        state_dict['model'] = extract_expert_param(state_dict['model'], 'none')

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict['optimizer'] = optimizer.state_dict()
        if lr_scheduler is not None:
            state_dict['lr_scheduler'] = lr_scheduler.state_dict()

    # RNG states.
    if not args.no_save_rng:
        state_dict['random_rng_state'] = random.getstate()
        state_dict['np_rng_state'] = np.random.get_state()
        state_dict['torch_rng_state'] = torch.get_rng_state()
        state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
        state_dict['rng_tracker_states'] \
            = mpu.get_cuda_rng_tracker().get_states()

    # Save.
    checkpoint_name = get_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename
    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()

    if torch.distributed.get_rank() == 0:
        print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
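
For illustration only, not part of the commit: the change makes every data parallel rank write its own checkpoint file (paths such as iter_0001000/mp_rank_00_dp_rank_0003/model_optim_rng.pt follow from the format strings in get_checkpoint_name; the iteration and ranks here are made-up examples). Non-zero data parallel ranks call state_dict_for_save_checkpoint(keep_vars=True) so the entries are live parameters, then keep only those whose dp_comm attribute equals 'none', presumably FastMoE's tag for expert parameters that are not synchronized across the data parallel group. The minimal sketch below reproduces that filtering on a toy state dict; the module and parameter names are hypothetical.

from collections import OrderedDict
import torch
import torch.nn as nn

def extract_expert_param(state_dict, expert_dp_comm='none'):
    # Same recursive filter as in the diff: keep only entries whose tensors
    # are tagged with dp_comm == expert_dp_comm, descending into nested
    # (Ordered)dicts and dropping everything else.
    state_dict_new = state_dict.__class__()
    for k, v in state_dict.items():
        if isinstance(v, (OrderedDict, dict)):
            v_new = extract_expert_param(v, expert_dp_comm)
            if len(v_new):
                state_dict_new[k] = v_new
        elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
            state_dict_new[k] = v.detach()
    return state_dict_new

# Hypothetical toy state dict: 'expert' stands in for an expert weight that
# differs across data parallel ranks, 'dense' for an ordinary replicated weight.
expert_w = nn.Parameter(torch.randn(4, 4))
expert_w.dp_comm = 'none'   # assumed tagging convention, mirroring the diff
dense_w = nn.Parameter(torch.randn(4, 4))

state = OrderedDict(language_model=OrderedDict(expert=expert_w, dense=dense_w))
print(list(extract_expert_param(state)['language_model'].keys()))  # -> ['expert']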