OpenDAS / FastMoE · Commits

Commit 49a4678c, authored Mar 16, 2021 by Jiezhong Qiu

    save experts separately in each data parallel rank

Parent: 89d6c794
Showing 1 changed file with 105 additions and 0 deletions.

fmoe/megatron.py  (+105 -0)
@@ -3,8 +3,11 @@ The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
 lines of modification.
 See `examples/megatron` for usage instructions.
 """
+import os
 import math
 import numpy as np
+import random
+from collections import OrderedDict
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -361,3 +364,105 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
         Keep consistency with Megatron
         """
         return self.module.load_state_dict(*args, **kwargs)
+
+
+def get_checkpoint_name(checkpoints_path, iteration, release=False):
+    """A unified checkpoint name."""
+    from megatron import mpu
+    if release:
+        directory = 'release'
+    else:
+        directory = 'iter_{:07d}'.format(iteration)
+    # Use both the tensor and pipeline MP rank.
+    if mpu.get_pipeline_model_parallel_world_size() == 1:
+        return os.path.join(checkpoints_path, directory,
+                            'mp_rank_{:02d}_dp_rank_{:04d}'.format(
+                                mpu.get_tensor_model_parallel_rank(),
+                                mpu.get_data_parallel_rank()),
+                            'model_optim_rng.pt')
+    return os.path.join(checkpoints_path, directory,
+                        'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
+                            mpu.get_tensor_model_parallel_rank(),
+                            mpu.get_pipeline_model_parallel_rank(),
+                            mpu.get_data_parallel_rank()),
+                        'model_optim_rng.pt')
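Note (editor's sketch, not part of the commit): unlike stock Megatron-LM, the checkpoint name now also encodes the data parallel rank, which is what lets each DP rank write its own file. Assuming hypothetical values (iteration 1000, tensor MP rank 0, pipeline world size 1, DP rank 3), the format strings above resolve as follows; only the format strings come from the commit:

    import os

    directory = 'iter_{:07d}'.format(1000)  # -> 'iter_0001000'
    path = os.path.join('/ckpts', directory,
                        'mp_rank_{:02d}_dp_rank_{:04d}'.format(0, 3),
                        'model_optim_rng.pt')
    print(path)  # /ckpts/iter_0001000/mp_rank_00_dp_rank_0003/model_optim_rng.pt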
+
+
+def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+    """Save a model checkpoint with expert parallelism."""
+    from megatron import get_args
+    from megatron import mpu
+    args = get_args()
+
+    # Only rank zero of the data parallel writes to the disk.
+    if isinstance(model, DistributedDataParallel):
+        model = model.module
+
+    if torch.distributed.get_rank() == 0:
+        print('saving checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
+
+    data_parallel_rank = mpu.get_data_parallel_rank()
+
+    # Arguments, iteration, and model.
+    state_dict = {}
+    state_dict['args'] = args
+    state_dict['checkpoint_version'] = 3.0
+    state_dict['iteration'] = iteration
+    keep_vars = False if mpu.get_data_parallel_rank() == 0 else True
+    state_dict['model'] = model.state_dict_for_save_checkpoint(
+        keep_vars=keep_vars)
+    if mpu.get_data_parallel_rank() != 0:
+        def extract_expert_param(state_dict, expert_dp_comm='none'):
+            state_dict_new = state_dict.__class__()
+            for k, v in state_dict.items():
+                # megatron uses both dict and OrderedDict in its state_dict
+                if isinstance(v, OrderedDict) or isinstance(v, dict):
+                    v_new = extract_expert_param(v, expert_dp_comm)
+                    if len(v_new):
+                        state_dict_new[k] = v_new
+                elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
+                    state_dict_new[k] = v.detach()
+            return state_dict_new
+        state_dict['model'] = extract_expert_param(state_dict['model'],
+                                                   'none')
+
+    # Optimizer stuff.
+    if not args.no_save_optim:
+        if optimizer is not None:
+            state_dict['optimizer'] = optimizer.state_dict()
+        if lr_scheduler is not None:
+            state_dict['lr_scheduler'] = lr_scheduler.state_dict()
+
+    # RNG states.
+    if not args.no_save_rng:
+        state_dict['random_rng_state'] = random.getstate()
+        state_dict['np_rng_state'] = np.random.get_state()
+        state_dict['torch_rng_state'] = torch.get_rng_state()
+        state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
+        state_dict['rng_tracker_states'] \
+            = mpu.get_cuda_rng_tracker().get_states()
+
+    # Save.
+    checkpoint_name = get_checkpoint_name(args.save, iteration)
+    from megatron.checkpointing import ensure_directory_exists
+    from megatron.checkpointing import get_checkpoint_tracker_filename
+    ensure_directory_exists(checkpoint_name)
+    torch.save(state_dict, checkpoint_name)
+
+    # Wait so everyone is done (necessary)
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
+            iteration, args.save), flush=True)
+
+    # And update the latest iteration
+    if torch.distributed.get_rank() == 0:
+        tracker_filename = get_checkpoint_tracker_filename(args.save)
+        with open(tracker_filename, 'w') as f:
+            f.write(str(iteration))
+
+    # Wait so everyone is done (not necessary)
+    torch.distributed.barrier()
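Note (editor's sketch, not part of the commit): in an MoE model the expert parameters differ across data parallel ranks, so DP rank 0 alone cannot persist them; that is why non-zero DP ranks also write, but only the expert subset. keep_vars=True on those ranks keeps the Parameter objects, and with them their dp_comm tags, in the state dict so extract_expert_param can filter on them. A minimal, self-contained sketch of that filter, assuming parameters carry a dp_comm attribute as FastMoE attaches elsewhere; the tagging and names below are illustrative, and getattr stands in for the commit's hasattr check:

    from collections import OrderedDict
    import torch

    def extract_expert_param(state_dict, expert_dp_comm='none'):
        # Recursively keep only tensors whose dp_comm tag matches.
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            if isinstance(v, (dict, OrderedDict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                if len(v_new):
                    state_dict_new[k] = v_new
            elif getattr(v, 'dp_comm', None) == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    expert_w = torch.nn.Parameter(torch.zeros(2))
    expert_w.dp_comm = 'none'  # expert-local: differs across DP ranks
    shared_w = torch.nn.Parameter(torch.zeros(2))
    shared_w.dp_comm = 'dp'    # replicated: rank 0's copy suffices
    sd = OrderedDict(mlp=OrderedDict(expert=expert_w, shared=shared_w))
    print(list(extract_expert_param(sd)['mlp']))  # ['expert']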