wuxk1 / Megatron-LM · Commits

Commit 86e7d624, authored Mar 28, 2020 by Mohammad

    refactored checkpoints but not tested yet

Parent: 11220df8
Showing 2 changed files with 282 additions and 44 deletions (+282, −44)
megatron/checkpointing.py    +245   −0
megatron/utils.py            +37    −44
megatron/checkpointing.py  (new file, 0 → 100644)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input/output checkpointing."""

import os
import random
import sys
import numpy as np

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import mpu
from .global_vars import get_args
from .utils import print_rank_0


def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the one retrieved from checkpoint."""
    args = get_args()

    def _compare(arg_name):
        checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    _compare('num_layers')
    _compare('hidden_size')
    _compare('num_attention_heads')
    _compare('max_position_embeddings')
    _compare('make_vocab_size_divisible_by')
    _compare('padded_vocab_size')
    _compare('tokenizer_type')
    _compare('model_parallel_size')


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration,
                        release=False, mp_rank=None):
    """A unified checkpoint name."""
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}'.format(
                            mpu.get_model_parallel_rank() if mp_rank is None
                            else mp_rank),
                        'model_optim_rng.pt')


def get_checkpoint_tracker_filename(checkpoints_path):
    """Tracker file records the latest checkpoint during
    training to restart from."""
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    if isinstance(model, torchDDP):
        model = model.module
    if mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['iteration'] = iteration
        state_dict['model'] = model.state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
            if optimizer is not None:
                state_dict['optimizer'] = optimizer.state_dict()
            if lr_scheduler is not None:
                state_dict['lr_scheduler'] = lr_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict['random_rng_state'] = random.getstate()
            state_dict['np_rng_state'] = np.random.get_state()
            state_dict['torch_rng_state'] = torch.get_rng_state()
            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
            state_dict['rng_tracker_states'] \
                = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
              format(torch.distributed.get_rank(), iteration, checkpoint_name))
        ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)
        print('  successfully saved {}'.format(checkpoint_name))

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()


def load_checkpoint(model, optimizer, lr_scheduler):
    """Load a model checkpoint and return the iteration."""
    args = get_args()

    if isinstance(model, torchDDP):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(args.load)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0('    will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # Load the checkpoint.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        # For backward compatibility.
        print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
    except:
        print_rank_0('could not load the checkpoint')
        sys.exit()

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:
                # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()

    # Check arguments.
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    model.load_state_dict(state_dict['model'])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return iteration
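
As a quick sanity check of the naming scheme above, the short sketch below calls get_checkpoint_name() and get_checkpoint_tracker_filename() directly. It is illustrative only and not part of this commit: the /tmp/ckpts path is made up, mp_rank is passed explicitly so no model-parallel group has to be initialized, and it assumes the repository root is on PYTHONPATH so megatron.checkpointing is importable.

# Illustrative only: exercises the path convention defined in this file.
from megatron.checkpointing import (get_checkpoint_name,
                                    get_checkpoint_tracker_filename)

# Passing mp_rank explicitly skips the mpu.get_model_parallel_rank() call.
name = get_checkpoint_name('/tmp/ckpts', 500, mp_rank=3)
print(name)      # /tmp/ckpts/iter_0000500/mp_rank_03/model_optim_rng.pt

tracker = get_checkpoint_tracker_filename('/tmp/ckpts')
print(tracker)   # /tmp/ckpts/latest_checkpointed_iteration.txt
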
megatron/utils.py
...

@@ -22,15 +22,11 @@ import numpy as np

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

#from megatron.global_vars import get_args
#from megatron.global_vars import get_adlr_autoresume
from megatron.global_vars import get_args
from megatron.global_vars import get_adlr_autoresume
from megatron import mpu
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization


def print_rank_0(message):
...

@@ -52,6 +48,41 @@ def reduce_losses(losses):

    return reduced_losses


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
    string += ' | max cached: {}'.format(
        torch.cuda.max_memory_cached() / mega_bytes)
    print_rank_0(string)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, model-parallel,min, max, norm\n'
    optimizer_ = optimizer
    if isinstance(optimizer, FP16_Optimizer):
        optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = param.data.norm()
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)


#######################################


def check_adlr_autoresume_termination(iteration, model, optimizer,
                                      lr_scheduler, args):
    # Add barrier to ensure consistency.
...

@@ -66,8 +97,6 @@ def check_adlr_autoresume_termination(iteration, model, optimizer,

        exit(0)


def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
...

@@ -125,42 +154,6 @@ def get_ltor_masks_and_position_ids(data,

    return attention_mask, loss_mask, position_ids


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, model-parallel,min, max, norm\n'
    optimizer_ = optimizer
    if isinstance(optimizer, FP16_Optimizer):
        optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = param.data.norm()
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
    string += ' | max cached: {}'.format(
        torch.cuda.max_memory_cached() / mega_bytes)
    print_rank_0(string)


def vocab_size_with_padding(num_tokens, args):

    after = num_tokens
...
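
The sketch below is a rough, hypothetical illustration of how the functions touched by this commit could be wired into a training loop. train_step() is a placeholder, and argument names such as train_iters, log_interval, and save_interval are assumed to exist on the parsed args; none of this is taken from the diff itself.

# Hypothetical wiring, not part of this commit: resume, train, log, save.
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.global_vars import get_args
from megatron.utils import print_params_min_max_norm, report_memory


def train_step(model, optimizer, lr_scheduler, data_iterator):
    """Placeholder for one forward/backward/update step."""


def train(model, optimizer, lr_scheduler, data_iterator):
    args = get_args()

    # Resume from the iteration recorded in latest_checkpointed_iteration.txt
    # under --load; returns 0 (fresh start) when no tracker file exists.
    iteration = load_checkpoint(model, optimizer, lr_scheduler)

    while iteration < args.train_iters:
        train_step(model, optimizer, lr_scheduler, data_iterator)
        iteration += 1

        if iteration % args.log_interval == 0:
            report_memory('iteration {}'.format(iteration))
            print_params_min_max_norm(optimizer, iteration)

        # Writes <save>/iter_XXXXXXX/mp_rank_XX/model_optim_rng.pt and, on
        # global rank 0, updates the tracker file to the new iteration.
        if args.save and iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)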