OpenDAS / FastMoE

Commit 6ca840b5, authored Mar 17, 2021 by Jiezhong Qiu
Parent: 1c69da9c

    load checkpoint for expert parallel

Changes: 1 changed file with 212 additions and 25 deletions (+212 −25)

fmoe/megatron.py @ 6ca840b5
@@ -4,6 +4,7 @@ lines of modification.
 See `examples/megatron` for usage instructions.
 """
 import os
+import sys
 import math
 import random
 from collections import OrderedDict
@@ -207,10 +208,15 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
         """
         return self.module.load_state_dict(*args, **kwargs)


-def get_checkpoint_name(checkpoints_path, iteration,
-                        release=False):
-    """A unified checkpoint name."""
+def get_fmoe_checkpoint_name(checkpoints_path, iteration,
+                             release=False, data_parallel_rank=-1):
+    """A unified checkpoint name, allowing specifying a data parallel rank"""
     from megatron import mpu
+    from megatron.checkpointing import get_checkpoint_name
+    if data_parallel_rank == -1:
+        data_parallel_rank = mpu.get_data_parallel_rank()
+    if data_parallel_rank == 0:
+        return get_checkpoint_name(checkpoints_path, iteration, release)
     if release:
         directory = 'release'
@@ -221,14 +227,14 @@ def get_checkpoint_name(checkpoints_path, iteration,
         return os.path.join(checkpoints_path, directory,
                             'mp_rank_{:02d}_dp_rank_{:04d}'.format(
                                 mpu.get_tensor_model_parallel_rank(),
-                                mpu.get_data_parallel_rank()),
+                                data_parallel_rank),
                             'model_optim_rng.pt')
     return os.path.join(checkpoints_path, directory,
                         'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                             mpu.get_tensor_model_parallel_rank(),
                             mpu.get_pipeline_model_parallel_rank(),
-                            mpu.get_data_parallel_rank()),
+                            data_parallel_rank),
                         'model_optim_rng.pt')
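A quick illustration of the naming scheme above: a minimal sketch with hypothetical values. The `iter_{:07d}` directory name is an assumption taken from the elided Megatron-style context, and the ranks are made up.

import os

checkpoints_path = '/tmp/checkpoints'   # hypothetical --save directory
directory = 'iter_0001000'              # assumed Megatron-style iteration directory
tp_rank, pp_rank, dp_rank = 0, 1, 3     # hypothetical tensor/pipeline/data-parallel ranks

name = os.path.join(checkpoints_path, directory,
                    'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                        tp_rank, pp_rank, dp_rank),
                    'model_optim_rng.pt')
print(name)  # /tmp/checkpoints/iter_0001000/mp_rank_00_001_dp_rank_0003/model_optim_rng.pt

Data-parallel rank 0 delegates to Megatron's native get_checkpoint_name, so its checkpoint keeps whatever layout Megatron itself produces; only ranks greater than zero get the extra dp_rank path component.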
@@ -238,8 +244,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     from megatron import get_args
     from megatron import mpu
+    if mpu.get_data_parallel_rank() == 0:
+        # at dp rank 0, we still follow the native save_checkpoint from megatron
+        from megatron.checkpointing import save_checkpoint as save_checkpoint_native
+        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
+        return
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
     if isinstance(model, DistributedDataParallel):
@@ -257,29 +268,26 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     state_dict['model'] = model.state_dict_for_save_checkpoint(
         keep_vars=(mpu.get_data_parallel_rank() > 0))

-    if mpu.get_data_parallel_rank() > 0:
-        def extract_expert_param(state_dict, expert_dp_comm='none'):
-            state_dict_new = state_dict.__class__()
-            for k, v in state_dict.items():
-                # megatron uses both dict and OrderedDict in its state_dict
-                if isinstance(v, (OrderedDict, dict)):
-                    v_new = extract_expert_param(v, expert_dp_comm)
-                    if len(v_new) > 0:
-                        state_dict_new[k] = v_new
-                elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
-                    state_dict_new[k] = v.detach()
-            return state_dict_new
-        state_dict['model'] = extract_expert_param(state_dict['model'],
-                                                   expert_dp_comm)
+    def extract_expert_param(state_dict, expert_dp_comm='none'):
+        state_dict_new = state_dict.__class__()
+        for k, v in state_dict.items():
+            # megatron uses both dict and OrderedDict in its state_dict
+            if isinstance(v, (OrderedDict, dict)):
+                v_new = extract_expert_param(v, expert_dp_comm)
+                if len(v_new) > 0:
+                    state_dict_new[k] = v_new
+            elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
+                state_dict_new[k] = v.detach()
+        return state_dict_new
+
+    state_dict['model'] = extract_expert_param(
+        state_dict['model'],
+        expert_dp_comm)

     # Optimizer stuff.
     if not args.no_save_optim:
         if optimizer is not None:
             state_dict['optimizer'] = optimizer.state_dict()
-            if mpu.get_data_parallel_rank() > 0:
-                index = 0
-                for param_group in optimizer.optimizer.param_groups:
-                    for param in param_group['params']:
+            index = 0
+            for param_group in optimizer.optimizer.param_groups:
+                for param in param_group['params']:
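For context, a minimal self-contained sketch (hypothetical parameter names, no Megatron required) of how the dp_comm tag drives extract_expert_param: only entries whose dp_comm attribute matches expert_dp_comm (the default 'none' used in this file) are kept, so data-parallel ranks greater than zero end up saving just their expert weights.

from collections import OrderedDict
import torch

expert_w = torch.nn.Parameter(torch.ones(2, 2))
expert_w.dp_comm = 'none'                         # tagged as an expert parameter
shared_w = torch.nn.Parameter(torch.ones(2, 2))   # untagged: a shared (non-expert) weight

state = OrderedDict(layer=OrderedDict(expert=expert_w, shared=shared_w))

def extract_expert_param(state_dict, expert_dp_comm='none'):
    # same filtering logic as above: recurse into nested dicts, keep tagged leaves
    state_dict_new = state_dict.__class__()
    for k, v in state_dict.items():
        if isinstance(v, (OrderedDict, dict)):
            v_new = extract_expert_param(v, expert_dp_comm)
            if len(v_new) > 0:
                state_dict_new[k] = v_new
        elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
            state_dict_new[k] = v.detach()
    return state_dict_new

print(list(extract_expert_param(state)['layer'].keys()))  # ['expert']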
@@ -304,7 +312,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
             = mpu.get_cuda_rng_tracker().get_states()

     # Save.
-    checkpoint_name = get_checkpoint_name(args.save, iteration)
+    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
     from megatron.checkpointing import ensure_directory_exists
     from megatron.checkpointing import get_checkpoint_tracker_filename
     ensure_directory_exists(checkpoint_name)
@@ -322,3 +330,182 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
             f.write(str(iteration))

     # Wait so everyone is done (not necessary)
     torch.distributed.barrier()
+
+
+def merge_state_dict(state_dict_rank0, state_dict_local):
+    """merge two state dicts, one from data parallel rank 0,
+    another only contains expert states"""
+    from megatron import print_rank_last
+
+    def merge_model(state_dict_rank0, state_dict_local):
+        for k, v in state_dict_local.items():
+            # megatron uses both dict and OrderedDict in its state_dict
+            if isinstance(v, (OrderedDict, dict)):
+                print_rank_last("[merge model] go recursively to {}".format(k))
+                merge_model(state_dict_rank0[k], v)
+            else:
+                before = state_dict_rank0[k].sum().item()
+                state_dict_rank0[k] = v
+                after = state_dict_rank0[k].sum().item()
+                print_rank_last("[merge model] copy parameter {}, \
+before.sum={:7f}, after.sum={:7f}".format(k, before, after))
+
+    merge_model(state_dict_rank0['model'], state_dict_local['model'])
+
+    for k, v in state_dict_local['optimizer']['state'].items():
+        before = {kk: vv.sum().item() \
+            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+        state_dict_rank0['optimizer']['state'][k] = v
+        after = {kk: vv.sum().item() \
+            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+        print_rank_last("[merge optimizer] copy {}, \
+before.sum={}, after.sum={}".format(k, str(before), str(after)))
+
+    return state_dict_rank0
+
+
+def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
+    """Load a model checkpoint and return the iteration."""
+    from megatron import get_args
+    from megatron import mpu
+    from megatron import print_rank_last
+    from megatron.checkpointing import get_checkpoint_tracker_filename, \
+        set_checkpoint_version, check_checkpoint_args, update_num_microbatches
+
+    if mpu.get_data_parallel_rank() == 0:
+        # at dp rank 0, we still follow the native load_checkpoint by megatron
+        from megatron.checkpointing import load_checkpoint as load_checkpoint_native
+        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)
+
+    args = get_args()
+    load_dir = getattr(args, load_arg)
+
+    if isinstance(model, DistributedDataParallel):
+        model = model.module
+
+    # Read the tracker file and set the iteration.
+    tracker_filename = get_checkpoint_tracker_filename(load_dir)
+
+    # If no tracker file, return iteration zero.
+    if not os.path.isfile(tracker_filename):
+        print_rank_last('WARNING: could not find the metadata file {} '.format(
+            tracker_filename))
+        print_rank_last(' will not load any checkpoints and will start from '
+                        'random')
+        return 0
+
+    # Otherwise, read the tracker file and either set the iteration or
+    # mark it as a release checkpoint.
+    iteration = 0
+    release = False
+    with open(tracker_filename, 'r') as f:
+        metastring = f.read().strip()
+        try:
+            iteration = int(metastring)
+        except ValueError:
+            release = metastring == 'release'
+            if not release:
+                print_rank_last('ERROR: Invalid metadata file {}. Exiting'.format(
+                    tracker_filename))
+                sys.exit()
+
+    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
+        tracker_filename)
+
+    # Checkpoint.
+    checkpoint_name_rank0 = get_fmoe_checkpoint_name(load_dir, iteration,
+                                                     release, 0)
+    checkpoint_name_local = get_fmoe_checkpoint_name(load_dir, iteration,
+                                                     release,
+                                                     mpu.get_data_parallel_rank())
+    print_rank_last(' loading checkpoint at rank 0 from {} and rank {} from {} '
+                    'at iteration {}, will merge them later'.format(
+                        checkpoint_name_rank0, mpu.get_data_parallel_rank(),
+                        checkpoint_name_local, iteration))
+
+    # Load the checkpoint.
+    def load_state_dict(checkpoint_name):
+        try:
+            state_dict = torch.load(checkpoint_name, map_location='cpu')
+        except ModuleNotFoundError:
+            from megatron.fp16_deprecated import loss_scaler
+            # For backward compatibility.
+            print_rank_last(' > deserializing using the old code structure ...')
+            sys.modules['fp16.loss_scaler'] = sys.modules[
+                'megatron.fp16_deprecated.loss_scaler']
+            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
+                'megatron.fp16_deprecated.loss_scaler']
+            state_dict = torch.load(checkpoint_name, map_location='cpu')
+            sys.modules.pop('fp16.loss_scaler', None)
+            sys.modules.pop('megatron.fp16.loss_scaler', None)
+        except BaseException:
+            print_rank_last('could not load the checkpoint')
+            sys.exit()
+        return state_dict
+
+    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
+    state_dict_local = load_state_dict(checkpoint_name_local)
+
+    state_dict = merge_state_dict(state_dict_rank0, state_dict_local)
+
+    # set checkpoint version
+    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+
+    # Set iteration.
+    if args.finetune or release:
+        iteration = 0
+    else:
+        try:
+            iteration = state_dict['iteration']
+        except KeyError:
+            try:
+                # Backward compatible with older checkpoints
+                iteration = state_dict['total_iters']
+            except KeyError:
+                print_rank_last('A metadata file exists but unable to load '
+                                'iteration from checkpoint {}, exiting'.format(
+                                    checkpoint_name_local))
+                sys.exit()
+
+    # Check arguments.
+    assert args.consumed_train_samples == 0
+    assert args.consumed_valid_samples == 0
+    if 'args' in state_dict:
+        checkpoint_args = state_dict['args']
+        check_checkpoint_args(checkpoint_args)
+        args.consumed_train_samples = getattr(checkpoint_args,
+                                              'consumed_train_samples', 0)
+        update_num_microbatches(consumed_samples=args.consumed_train_samples)
+        args.consumed_valid_samples = getattr(checkpoint_args,
+                                              'consumed_valid_samples', 0)
+    else:
+        print_rank_last('could not find arguments in the checkpoint ...')
+
+    # Model.
+    model.load_state_dict(state_dict['model'])
+
+    # Optimizer.
+    if not release and not args.finetune and not args.no_load_optim:
+        try:
+            if optimizer is not None:
+                optimizer.load_state_dict(state_dict['optimizer'])
+            if lr_scheduler is not None:
+                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+        except KeyError:
+            print_rank_last('Unable to load optimizer from checkpoint {}. '
+                            'Specify --no-load-optim or --finetune to prevent '
+                            'attempting to load the optimizer state, '
+                            'exiting ...'.format(checkpoint_name_local))
+            sys.exit()
+
+    # rng states.
+    if not release and not args.finetune and not args.no_load_rng:
+        try:
+            random.setstate(state_dict['random_rng_state'])
+            np.random.set_state(state_dict['np_rng_state'])
+            torch.set_rng_state(state_dict['torch_rng_state'])
+            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+            mpu.get_cuda_rng_tracker().set_states(
+                state_dict['rng_tracker_states'])
+        except KeyError:
+            print_rank_last('Unable to load rng state from checkpoint {}. '
+                            'Specify --no-load-rng or --finetune to prevent '
+                            'attempting to load the rng state, '
+                            'exiting ...'.format(checkpoint_name_local))
+            sys.exit()
+
+    torch.distributed.barrier()
+    print_rank_last(' successfully loaded checkpoint (with expert parameters updated) '
+                    'from {} at iteration {}'.format(args.load, iteration))
+
+    return iteration
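To illustrate the load path, a minimal self-contained sketch (hypothetical tensors and key names, no Megatron required) of the merge semantics implemented by merge_state_dict above: the rank-0 checkpoint supplies every shared parameter, and the rank-local expert-only checkpoint overwrites just the expert entries before the merged state dict is passed to model.load_state_dict.

from collections import OrderedDict
import torch

rank0 = OrderedDict(model=OrderedDict(
    attention=OrderedDict(weight=torch.zeros(2, 2)),   # shared weight, kept from rank 0
    experts=OrderedDict(htoh4=torch.zeros(2, 2)),      # stale expert copy from rank 0
))
local = OrderedDict(model=OrderedDict(
    experts=OrderedDict(htoh4=torch.ones(2, 2)),       # this rank's own expert weights
))

def merge_model(dst, src):
    # recursively overwrite dst entries with the expert-only entries found in src
    for k, v in src.items():
        if isinstance(v, (OrderedDict, dict)):
            merge_model(dst[k], v)
        else:
            dst[k] = v

merge_model(rank0['model'], local['model'])
assert rank0['model']['experts']['htoh4'].sum().item() == 4.0     # expert slice replaced
assert rank0['model']['attention']['weight'].sum().item() == 0.0  # shared weight untouched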