Commit 651cc107 authored by Mostofa Patwary

Merging megatron with ICT

parents e919dd8e c601d751
image: gitlab-master.nvidia.com/dl/dgx/pytorch:20.12-py3-devel

test:
  script:
    - pytest --junitxml=report.xml tests
  artifacts:
    when: always
    reports:
      junit: report.xml
\ No newline at end of file
@@ -29,7 +29,7 @@ The following applies to all files unless otherwise noted:
 --
 This repository also contains code from Hugging Face Inc., Google Research,
-and Facebook (from their Fairseq project). Files from these
+Facebook (from their Fairseq project), and Philip Popien. Files from these
 organizations have notices at the top of each file. Below are licenses
 used in those files, as indicated.
@@ -262,3 +262,4 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
This diff is collapsed.
@@ -28,11 +28,11 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 8 \
+       --micro-batch-size 8 \
        --checkpoint-activations \
        --lr 5.0e-5 \
        --lr-decay-style linear \
-       --warmup 0.065 \
+       --lr-warmup-fraction 0.065 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --save-interval 500000 \
...
@@ -28,11 +28,11 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 4 \
+       --micro-batch-size 4 \
        --checkpoint-activations \
        --lr 1.0e-5 \
        --lr-decay-style linear \
-       --warmup 0.06 \
+       --lr-warmup-fraction 0.06 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --save-interval 100000 \
...
@@ -9,24 +9,24 @@ python pretrain_bert.py \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 4 \
+       --micro-batch-size 4 \
+       --global-batch-size 8 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --train-iters 2000000 \
+       --lr-decay-iters 990000 \
        --save $CHECKPOINT_PATH \
        --load $CHECKPOINT_PATH \
        --data-path $DATA_PATH \
        --vocab-file bert-vocab.txt \
        --data-impl mmap \
        --split 949,50,1 \
-       --distributed-backend nccl \
        --lr 0.0001 \
        --min-lr 0.00001 \
        --lr-decay-style linear \
-       --lr-decay-iters 990000 \
+       --lr-warmup-fraction .01 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \
...
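Across these example scripts the pattern is the same: `--batch-size` becomes `--micro-batch-size` (the batch each data-parallel rank processes in one forward/backward pass), a new `--global-batch-size` fixes the total samples per iteration, `--warmup` becomes `--lr-warmup-fraction`, and the old `--num-microbatches-in-minibatch` disappears because the microbatch count is now derived from the other two. A small sanity check of that assumed relationship (a sketch inferred from the flag values, not Megatron's actual code):

# Assumed: global batch = micro batch * data-parallel size * microbatch count.
def num_microbatches(global_batch, micro_batch, data_parallel_size):
    denom = micro_batch * data_parallel_size
    assert global_batch % denom == 0, 'global batch must divide evenly'
    return global_batch // denom

print(num_microbatches(8, 4, 1))    # pretrain_bert.sh above: 2 microbatches
print(num_microbatches(64, 8, 8))   # e.g. 8-way data parallel (hypothetical): 1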
@@ -15,11 +15,11 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
 python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        pretrain_bert.py \
+       --tensor-model-parallel-size 1 \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 4 \
+       --micro-batch-size 4 \
+       --global-batch-size 32 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
@@ -36,7 +36,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --lr-decay-iters 990000 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \
...
@@ -20,8 +20,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 2 \
-       --num-microbatches-in-minibatch 2 \
+       --micro-batch-size 2 \
+       --global-batch-size 16 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
@@ -38,7 +38,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --lr-decay-iters 990000 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --log-interval 100 \
        --save-interval 10000 \
        --eval-interval 1000 \
...
@@ -9,11 +9,12 @@ DATA_PATH=<Specify path and file prefix>_text_document
 CHECKPOINT_PATH=<Specify path>

-python pretrain_gpt2.py \
+python pretrain_gpt.py \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 8 \
+       --micro-batch-size 4 \
+       --global-batch-size 8 \
        --seq-length 1024 \
        --max-position-embeddings 1024 \
        --train-iters 500000 \
@@ -31,7 +32,7 @@ python pretrain_gpt2.py \
        --lr-decay-style cosine \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --checkpoint-activations \
        --log-interval 100 \
        --save-interval 10000 \
...
#!/bin/bash
#SBATCH <SLURM OPTIONS> --nodes=128 --exclusive --ntasks-per-node=8 --job-name=megatron_gpt3_175b
DIR=`pwd`
DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
mkdir -p $DIR/logs
DATASET_1="<PATH TO THE FIRST DATASET>"
DATASET_2="<PATH TO THE SECOND DATASET>"
DATASET_3="<PATH TO THE THIRD DATASET>"
DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}"
options=" \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 16 \
--num-layers 96 \
--hidden-size 12288 \
--num-attention-heads 96 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--global-batch-size 1536 \
--rampup-batch-size 16 16 5859375 \
--train-samples 146484375 \
--lr-decay-samples 126953125 \
--lr-warmup-samples 183105 \
--lr 6.0e-5 \
--min-lr 6.0e-6 \
--lr-decay-style cosine \
--log-interval 10 \
--eval-iters 40 \
--eval-interval 1000 \
--data-path ${DATASET} \
--vocab-file <PATH TO gpt-vocab.json> \
--merge-file <PATH TO gpt-merges.txt> \
--save-interval 1000 \
--save <PATH TO CHECKPOINTS DIRECTORY> \
--load <PATH TO CHECKPOINTS DIRECTORY> \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.006 \
--tensorboard-dir <TENSORBOARD DIRECTORY> \
--fp16 \
--checkpoint-activations "
run_cmd="python -u ${DIR}/pretrain_gpt.py $@ ${options}"
srun -l \
--container-image "nvcr.io/nvidia/pytorch:20.12-py3" \
--container-mounts "<DIRECTORIES TO MOUNT>" \
--output=$DIR/logs/%x_%j_$DATETIME.log sh -c "${run_cmd}"
set +x
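In the 175B script above, `--rampup-batch-size 16 16 5859375` reads as: start the global batch at 16 and grow it in increments of 16 over the first 5,859,375 samples, after which the full `--global-batch-size 1536` applies. A sketch of that schedule under this assumed interpretation (the exact Megatron ramp logic may differ):

def global_batch_size(consumed_samples, start=16, increment=16,
                      ramp_samples=5859375, final=1536):
    # Assumed linear ramp: the batch grows by `increment` at evenly
    # spaced sample counts until it reaches `final`.
    steps = (final - start) // increment        # 95 increments in total
    samples_per_step = ramp_samples // steps    # samples between bumps
    bumps = min(consumed_samples // samples_per_step, steps)
    return start + bumps * increment

print(global_batch_size(0))          # 16 at the start of training
print(global_batch_size(5859375))    # 1536 once the ramp completes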
@@ -16,12 +16,12 @@ CHECKPOINT_PATH=<Specify path>
 DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

 python -m torch.distributed.launch $DISTRIBUTED_ARGS \
-       pretrain_gpt2.py \
+       pretrain_gpt.py \
+       --tensor-model-parallel-size 1 \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 8 \
+       --micro-batch-size 8 \
+       --global-batch-size 64 \
        --seq-length 1024 \
        --max-position-embeddings 1024 \
        --train-iters 500000 \
@@ -39,7 +39,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --min-lr 1.0e-5 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --checkpoint-activations \
        --log-interval 100 \
        --save-interval 10000 \
...
@@ -16,14 +16,14 @@ CHECKPOINT_PATH=<Specify path>
 DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

 python -m torch.distributed.launch $DISTRIBUTED_ARGS \
-       pretrain_gpt2.py \
+       pretrain_gpt.py \
        --tensor-model-parallel-size 2 \
        --pipeline-model-parallel-size 2 \
        --num-layers 24 \
        --hidden-size 1024 \
        --num-attention-heads 16 \
-       --batch-size 4 \
-       --num-microbatches-in-minibatch 2 \
+       --micro-batch-size 4 \
+       --global-batch-size 16 \
        --seq-length 1024 \
        --max-position-embeddings 1024 \
        --train-iters 500000 \
@@ -41,7 +41,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --min-lr 1.0e-5 \
        --weight-decay 1e-2 \
        --clip-grad 1.0 \
-       --warmup .01 \
+       --lr-warmup-fraction .01 \
        --checkpoint-activations \
        --log-interval 100 \
        --save-interval 10000 \
...
This diff is collapsed.
@@ -23,16 +23,18 @@ import numpy as np
 import torch
 from torch.nn.parallel import DistributedDataParallel as torchDDP

-from megatron import mpu, get_args, update_num_microbatches
-from megatron import get_args
-from megatron import print_rank_0
+from megatron import (get_args,
+                      mpu,
+                      print_rank_0,
+                      update_num_microbatches)

 _CHECKPOINT_VERSION = None


 def set_checkpoint_version(value):
     global _CHECKPOINT_VERSION
-    assert _CHECKPOINT_VERSION is None, \
-        "checkpoint version already set"
+    if _CHECKPOINT_VERSION is not None:
+        assert _CHECKPOINT_VERSION == value, \
+            "checkpoint versions do not match"
     _CHECKPOINT_VERSION = value


 def get_checkpoint_version():
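A behavioral note on the hunk above: `set_checkpoint_version` used to assert that the version had never been set, and now tolerates repeated calls as long as the value agrees, likely because the merged ICT code loads more than one checkpoint in a single process. A standalone illustration of the new semantics:

_CHECKPOINT_VERSION = None

def set_checkpoint_version(value):
    # Mirrors the new logic: repeated calls are fine if the values agree.
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, \
            "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value

set_checkpoint_version(2.0)
set_checkpoint_version(2.0)   # now a no-op; the old code would have asserted
# set_checkpoint_version(3.0) would still raise, since the versions differ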
@@ -59,9 +61,10 @@ def check_checkpoint_args(checkpoint_args):
     _compare('hidden_size')
     _compare('num_attention_heads')
     _compare('max_position_embeddings')
-    _compare('make_vocab_size_divisible_by')
-    _compare('padded_vocab_size')
-    _compare('tokenizer_type')
+    if args.vocab_file:
+        _compare('make_vocab_size_divisible_by')
+        _compare('padded_vocab_size')
+        _compare('tokenizer_type')
     if get_checkpoint_version() < 3.0:
         _compare('tensor_model_parallel_size',
                  old_arg_name='model_parallel_size')
@@ -111,11 +114,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     if isinstance(model, torchDDP):
         model = model.module

-    if torch.distributed.get_rank() == 0:
-        print('saving checkpoint at iteration {:7d} to {}'.format(
-            iteration, args.save), flush=True)
+    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
+        iteration, args.save))

-    if mpu.get_data_parallel_rank() == 0:
+    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

         # Arguments, iteration, and model.
         state_dict = {}
@@ -146,21 +148,93 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         torch.save(state_dict, checkpoint_name)

     # Wait so everyone is done (necessary)
-    torch.distributed.barrier()
-    if torch.distributed.get_rank() == 0:
-        print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
-            iteration, args.save), flush=True)
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
+
+    print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
+        iteration, args.save))

     # And update the latest iteration
-    if torch.distributed.get_rank() == 0:
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
         with open(tracker_filename, 'w') as f:
             f.write(str(iteration))
-    # Wait so everyone is done (not necessary)
-    torch.distributed.barrier()

-def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
-    """Load a model checkpoint and return the iteration."""
+    # Wait so everyone is done (not necessary)
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
+
+
+def _transpose_first_dim(t, num_splits, num_splits_first, model):
+    input_shape = t.size()
+    # We use a self_attention module but the values extracted aren't
+    # specific to self attention so should work for cross attention as well
+    while hasattr(model, 'module'):
+        model = model.module
+    attention_module = model.language_model.encoder.layers[0].self_attention
+    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
+    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
+    if num_splits_first:
+        """[num_splits * np * hn, h]
+        -->(view) [num_splits, np, hn, h]
+        -->(transpose) [np, num_splits, hn, h]
+        -->(view) [np * num_splits * hn, h] """
+        intermediate_shape = \
+            (num_splits, num_attention_heads_per_partition,
+             hidden_size_per_attention_head) + input_shape[1:]
+        t = t.view(*intermediate_shape)
+        t = t.transpose(0, 1).contiguous()
+    else:
+        """[np * hn * num_splits, h]
+        -->(view) [np, hn, num_splits, h]
+        -->(transpose) [np, num_splits, hn, h]
+        -->(view) [np * num_splits * hn, h] """
+        intermediate_shape = \
+            (num_attention_heads_per_partition,
+             hidden_size_per_attention_head, num_splits) +\
+            input_shape[1:]
+        t = t.view(*intermediate_shape)
+        t = t.transpose(1, 2).contiguous()
+    t = t.view(*input_shape)
+
+    return t
+
+
+def fix_query_key_value_ordering(model, checkpoint_version):
+    """Fix up query/key/value matrix ordering if checkpoint
+    version is smaller than 2.0
+    """
+    if checkpoint_version < 2.0:
+        for name, param in model.named_parameters():
+            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)
+            if name.endswith(('.key_value.weight', '.key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)
+        print_rank_0(" successfully fixed query-key-values ordering for"
+                     " checkpoint version {}".format(checkpoint_version))
+
+
+def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
+    """Load a model checkpoint and return the iteration.
+    strict (bool): whether to strictly enforce that the keys in
+        :attr:`state_dict` of the checkpoint match the names of
+        parameters and buffers in model.
+    """
     args = get_args()
     load_dir = getattr(args, load_arg)
@@ -197,20 +271,22 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
     # Checkpoint.
     checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
-    if torch.distributed.get_rank() == 0:
-        print(' loading checkpoint from {} at iteration {}'.format(
-            args.load, iteration), flush=True)
+    print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')

     # Load the checkpoint.
     try:
         state_dict = torch.load(checkpoint_name, map_location='cpu')
     except ModuleNotFoundError:
+        from megatron.fp16_deprecated import loss_scaler
         # For backward compatibility.
         print_rank_0(' > deserializing using the old code structure ...')
         sys.modules['fp16.loss_scaler'] = sys.modules[
-            'megatron.fp16.loss_scaler']
+            'megatron.fp16_deprecated.loss_scaler']
+        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
+            'megatron.fp16_deprecated.loss_scaler']
         state_dict = torch.load(checkpoint_name, map_location='cpu')
         sys.modules.pop('fp16.loss_scaler', None)
+        sys.modules.pop('megatron.fp16.loss_scaler', None)
     except BaseException:
         print_rank_0('could not load the checkpoint')
         sys.exit()
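The `sys.modules` aliasing in this hunk is the standard trick for unpickling objects whose defining module has moved: `torch.load` resolves classes by the module path recorded in the pickle, so pointing the old paths (`fp16.loss_scaler`, `megatron.fp16.loss_scaler`) at the relocated `megatron.fp16_deprecated.loss_scaler` lets old checkpoints deserialize, and popping the aliases afterwards keeps the fake paths from leaking. A self-contained illustration of the pattern with made-up module names:

import pickle
import sys
import types

# Build a throwaway 'old_pkg' module and pickle an object from it.
old = types.ModuleType('old_pkg')

class Scaler:
    def __init__(self, scale):
        self.scale = scale

Scaler.__module__ = 'old_pkg'
old.Scaler = Scaler
sys.modules['old_pkg'] = old
blob = pickle.dumps(Scaler(4096.0))   # payload records 'old_pkg.Scaler'

# Simulate the relocation: 'old_pkg' is gone, 'new_pkg' now holds the class.
del sys.modules['old_pkg']
new = types.ModuleType('new_pkg')
new.Scaler = Scaler
sys.modules['new_pkg'] = new

# The aliasing trick: point the old path at the new module, unpickle,
# then drop the alias so it does not leak into later imports.
sys.modules['old_pkg'] = sys.modules['new_pkg']
print(pickle.loads(blob).scale)       # 4096.0
sys.modules.pop('old_pkg', None)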
@@ -248,7 +324,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         print_rank_0('could not find arguments in the checkpoint ...')

     # Model.
-    model.load_state_dict(state_dict['model'])
+    model.load_state_dict(state_dict['model'], strict=strict)
+
+    # Fix up query/key/value matrix ordering if needed
+    checkpoint_version = get_checkpoint_version()
+    print_rank_0(f' checkpoint version {checkpoint_version}')
+    fix_query_key_value_ordering(model, checkpoint_version)

     # Optimizer.
     if not release and not args.finetune and not args.no_load_optim:
@@ -280,10 +361,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                      'exiting ...'.format(checkpoint_name))
             sys.exit()

-    torch.distributed.barrier()
-    if torch.distributed.get_rank() == 0:
-        print(' successfully loaded checkpoint from {} at iteration {}'.format(
-            args.load, iteration), flush=True)
+    # Some utilities want to load a checkpoint without distributed being initialized
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
+
+    print_rank_0(f' successfully loaded checkpoint from {args.load} '
+                 f'at iteration {iteration}')

     return iteration
...
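For intuition about `_transpose_first_dim` above: in a version-0 checkpoint the fused QKV weight's first dimension is laid out as [num_splits, np, hn] (the three Q/K/V splits first), and the fix views it that way and swaps the first two axes to reach the new [np, num_splits, hn] layout. A toy reproduction of the num_splits_first branch with made-up sizes:

import torch

num_splits, heads, head_dim, hidden = 3, 2, 3, 4   # toy sizes, not real model dims
t = torch.arange(num_splits * heads * head_dim * hidden, dtype=torch.float32)
t = t.view(num_splits * heads * head_dim, hidden)

# [num_splits * np * hn, h] -> [np * num_splits * hn, h], as in the docstring.
fixed = (t.view(num_splits, heads, head_dim, hidden)
          .transpose(0, 1)
          .contiguous()
          .view(num_splits * heads * head_dim, hidden))

print(t.shape == fixed.shape)   # True: same shape, rows permuted
print(torch.equal(t, fixed))    # False: Q/K/V blocks were interleaved per head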