Commit abe36e2e authored by Raul Puri

large update including model parallelism and gpt2


Co-authored-by: shoeybi <shoeybim@gmail.com>
Co-authored-by: raulpuric <raulpuric@berkeley.edu>
Co-authored-by: jaredcasper <jaredcasper@gmail.com>
Co-authored-by: mpatwary <mostofa.patwary@gmail.com>
Co-authored-by: plegresl <plegresl@gmail.com>
parent 0399d32c
#!/bin/bash
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
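# Total number of workers: GPUs per node x number of nodes (8 x 1 = 8 here).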
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 4 \
--seq-length 512 \
--max-preds-per-seq 80 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--save checkpoints/bert_345m \
--load checkpoints/bert_345m \
--resume-dataloader \
--train-data wikipedia \
--lazy-loader \
--tokenizer-type BertWordPieceTokenizer \
--tokenizer-model-type bert-large-uncased \
--presplit-sentences \
--cache-dir cache \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--lr-decay-iters 990000 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--fp16 \
--fp32-layernorm \
--fp32-embedding
#!/bin/bash
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--model-parallel-size 2 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 4 \
--seq-length 512 \
--max-preds-per-seq 80 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--save checkpoints/bert_345m_mp2 \
--load checkpoints/bert_345m_mp2 \
--resume-dataloader \
--train-data wikipedia \
--lazy-loader \
--tokenizer-type BertWordPieceTokenizer \
--tokenizer-model-type bert-large-uncased \
--presplit-sentences \
--cache-dir cache \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--lr-decay-iters 990000 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--fp16 \
--fp32-layernorm \
--fp32-embedding
@@ -4,35 +4,32 @@ RANK=0
WORLD_SIZE=1
python pretrain_bert.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 4 \
--seq-length 512 \
--max-preds-per-seq 80 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--save checkpoints/bert_345m \
--load checkpoints/bert_345m \
--resume-dataloader \
--train-data wikipedia \
--lazy-loader \
--tokenizer-type SentencePieceTokenizer \
--tokenizer-model-type bpe \
--tokenizer-path tokenizer.model \
--presplit-sentences \
--cache-dir cache \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--lr-decay-iters 990000 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--fp16 \
--fp32-layernorm \
--fp32-embedding
#!/bin/bash
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_bert.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 4 \
--seq-length 512 \
--max-preds-per-seq 80 \
--max-position-embeddings 512 \
--train-iters 1000000 \
--save checkpoints/bert_345m \
--load checkpoints/bert_345m \
--resume-dataloader \
--use-tfrecords \
--train-data <TFRecord 1> <TFRecord 2> \
--valid-data <TFRecord 3> \
--test-data <TFRecord 4> \
--tokenizer-type BertWordPieceTokenizer \
--tokenizer-model-type bert-large-uncased \
--presplit-sentences \
--cache-dir cache \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--lr-decay-iters 990000 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--fp16 \
--fp32-layernorm \
--fp32-embedding
#! /bin/bash
# Runs the "345M" parameter model
RANK=0
WORLD_SIZE=1
python pretrain_gpt2.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 8 \
--seq-length 1024 \
--max-position-embeddings 1024 \
--train-iters 320000 \
--save checkpoints/gpt2_345m \
--load checkpoints/gpt2_345m \
--resume-dataloader \
--train-data wikipedia \
--lazy-loader \
--tokenizer-type GPT2BPETokenizer \
--cache-dir cache \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--checkpoint-activations \
--fp16
set +x
#! /bin/bash
# Runs the "345M" parameter model
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 8 \
--seq-length 1024 \
--max-position-embeddings 1024 \
--train-iters 320000 \
--save checkpoints/gpt2_345m \
--load checkpoints/gpt2_345m \
--resume-dataloader \
--train-data wikipedia \
--lazy-loader \
--tokenizer-type GPT2BPETokenizer \
--cache-dir cache \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--checkpoint-activations \
--fp16
set +x
#! /bin/bash
# Runs the "345M" parameter model
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--model-parallel-size 2 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--batch-size 8 \
--seq-length 1024 \
--max-position-embeddings 1024 \
--train-iters 320000 \
--save checkpoints/gpt2_345m_mp2 \
--load checkpoints/gpt2_345m_mp2 \
--resume-dataloader \
--train-data wikipedia \
--lazy-loader \
--tokenizer-type GPT2BPETokenizer \
--cache-dir cache \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--checkpoint-activations \
--fp16
set +x
"""
example usage:
python scripts/run_gpt2_eval.py \
--model-parallel-size 1 \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--model-path <gpt2_117_path> \
--data-path <wikitext_tokens_test_path> \
--batch-size 16 \
--cache-dir <cache dir path>
"""
import argparse
import subprocess
parser = argparse.ArgumentParser('run zero shot GPT2 eval')
parser.add_argument('--model-path', type=str, required=True,
help='Saved model path for evaluation')
parser.add_argument('--batch-size', type=int, default=4,
help='batch size to use for evaluation')
parser.add_argument('--num-attention-heads', type=int, default=12,
help='num of transformer attention heads')
parser.add_argument('--hidden-size', type=int, default=768,
                    help='transformer hidden size')
parser.add_argument('--num-layers', type=int, default=12,
help='num decoder layers')
parser.add_argument('--data-path', type=str, required=True,
help='Data path for evaluation data')
parser.add_argument('--cloze-eval', action='store_true',
help='Run lambada cloze eval instead of perplexity eval.')
parser.add_argument('--webtext-eval', action='store_true',
help='Run webtext PPL eval instead of wikitext PPL eval.')
parser.add_argument('--eval-iters', default=5000, type=int,
help='number of iterations to run webtext evaluation')
parser.add_argument('--model-parallel-size', type=int, default=1,
help='model parallel size to use')
parser.add_argument('--load-openai', action='store_true',
help='Load weights from saved openai/hf checkpoints')
parser.add_argument('--cache-dir', type=str, default='cache',
help='directory to cache gpt2 tokenizers')
args = parser.parse_args()
multinode_args = ''
if args.model_parallel_size > 1:
multinode_args += ' -m torch.distributed.launch --nproc_per_node {} '.format(args.model_parallel_size)
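# e.g. with --model-parallel-size 2 the prefix above becomes
# ' -m torch.distributed.launch --nproc_per_node 2 ', launching one process
# per model-parallel partition.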
CMD = ' --model-parallel-size {model_par} \
--num-layers {nlayers} \
--hidden-size {hidden} \
--log-interval 100 \
--load {model} \
--eval-batch-size {batch} \
--num-attention-heads {natt} \
--seq-length 1024 \
--max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer \
--text-key text \
--distributed-backend nccl \
--hidden-dropout 0.1 \
--attention-dropout 0.1 \
--fp16 \
--overlapping-eval 32 \
--cache-dir {cache} '.format(model_par=args.model_parallel_size,
nlayers=args.num_layers,
hidden=args.hidden_size,
model=args.model_path,
batch=args.batch_size,
natt=args.num_attention_heads,
cache=args.cache_dir)
if args.load_openai:
CMD += ' --load-openai '
if args.cloze_eval:
CMD += ' --cloze-eval '
CMD = 'evaluate_gpt2.py' + CMD
print('Running Lambada Eval Command:', flush=True)
elif args.webtext_eval:
CMD += '--train-iters 0 --eval-iters {} --test-data {} --loose-json '.format(args.eval_iters, args.data_path)
CMD = 'pretrain_gpt2.py' + CMD
print('Running Webtext Eval Command:', flush=True)
else:
CMD = 'evaluate_gpt2.py' + CMD
print('Running PPL Eval Command:', flush=True)
CMD = 'python3 '+multinode_args+CMD
print(CMD, flush=True)
subprocess.call(CMD.split())
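# Illustrative final command for the default wikitext PPL path:
#   python3 evaluate_gpt2.py --model-parallel-size 1 --num-layers 12 \
#       --hidden-size 768 --log-interval 100 --load <model-path> ... --cache-dir cache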
"""
Takes a corpus of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits the data into train.json, dev.json, and
test.json files under `--output_dir`.
Note: this script may overwrite existing files named train.json, dev.json, or
test.json in `--output_dir`.
"""
import os
import argparse
import math
import random
parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
parser.add_argument('--input_files', nargs='+', required=True,
help='whitespace separated list of input data files')
parser.add_argument('--output_dir', required=True,
help='output directory where to put files')
parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
                    help='fraction of available data to use for the dev/test datasets')
args = parser.parse_args()
def get_lines(filepath):
lines = []
with open(filepath, 'r') as f:
        for line in f:
            lines.append(line.strip())
return lines
def get_splits(lines, line_counts):
all_lines = []
line_idx = []
file_mappings = []
for i, l in enumerate(lines):
all_lines.extend(l)
line_idx.extend(list(range(len(l))))
file_mappings.extend([i]*len(l))
indices = list(range(len(all_lines)))
random.shuffle(indices)
all_lines = [all_lines[idx] for idx in indices]
line_idx = [line_idx[idx] for idx in indices]
file_mappings = [file_mappings[idx] for idx in indices]
splits = []
mappings = []
start = 0
for end in line_counts:
end += start
splits.append(all_lines[start:end])
mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
start = end
return splits, mappings
def format_mappings(line_idx, file_mappings):
lines = []
for m, l in zip(file_mappings, line_idx):
lines.append(str(m).strip()+'\t'+str(l).strip())
return lines
def get_filepaths(filepaths, output_dir):
paths = []
train_path = 'train.json'
dev_path = 'dev.json'
test_path = 'test.json'
paths.append(os.path.join(output_dir, train_path))
paths.append(os.path.join(output_dir, dev_path))
paths.append(os.path.join(output_dir, test_path))
return paths
def write_files(lines, mappings, filepaths):
for l, m, path in zip(lines, mappings, filepaths):
write_file(l, path)
write_mapping_file(m, path)
def write_file(lines, path):
print('Writing:', path)
with open(path, 'w') as f:
for l in lines:
f.write(l+'\n')
def write_mapping_file(m, path):
path = path+'.map'
m = [get_mapping_header()]+m
write_file(m, path)
def get_mapping_header():
return 'file\tline #'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
lines = []
for filepath in args.input_files:
_lines = get_lines(filepath)
lines.append(_lines)
# Calculate the number of lines to use for each split.
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent*total_lines)
test_percent = 0
if len(args.test_percent)==2:
test_percent=args.test_percent[1]
test_lines = math.ceil(test_percent*total_lines)
train_lines = total_lines-(test_lines+dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
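# Worked example: 1000 total lines with --test_percent 0.05 0.05 gives
# dev_lines = ceil(0.05 * 1000) = 50, test_lines = 50, train_lines = 900.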
splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)
@@ -21,6 +21,48 @@ import time
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from fp16 import FP16_Optimizer
import mpu
import model
def print_rank_0(message):
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def print_args(args):
"""Print arguments."""
print('arguments:', flush=True)
for arg in vars(args):
dots = '.' * (29 - len(arg))
print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, model-parallel, min, max, norm\n'
optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group['params']:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = param.data.norm()
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
class Timers:
"""Group of timers."""
@@ -85,7 +127,7 @@ class Timers:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0/ normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
        print_rank_0(string)
def report_memory(name):
@@ -100,81 +142,214 @@ def report_memory(name):
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
string += ' | max cached: {}'.format(
torch.cuda.max_memory_cached()/ mega_bytes)
    print_rank_0(string)
def get_checkpoint_name(checkpoints_path, iteration, release=False):
if release:
d = 'release'
else:
d = 'iter_{:07d}'.format(iteration)
return os.path.join(checkpoints_path, d,
'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()),
'model_optim_rng.pt')
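# Example result: <checkpoints_path>/iter_0010000/mp_rank_00/model_optim_rng.pt
# (or <checkpoints_path>/release/mp_rank_00/model_optim_rng.pt when release=True).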
def ensure_directory_exists(filename):
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
os.makedirs(dirname)
def get_checkpoint_tracker_filename(checkpoints_path):
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
def save_checkpoint(iteration, model, optimizer,
lr_scheduler, args):
"""Save a model checkpoint."""
    # Only rank zero of the data parallel group writes to the disk.
if isinstance(model, torchDDP):
model = model.module
if mpu.get_data_parallel_rank() == 0:
checkpoint_name = get_checkpoint_name(args.save, iteration)
print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
format(torch.distributed.get_rank(), iteration, checkpoint_name))
sd = {}
sd['iteration'] = iteration
sd['model'] = model.state_dict()
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
sd['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None:
sd['lr_scheduler'] = lr_scheduler.state_dict()
# rng states.
if not args.no_save_rng:
sd['random_rng_state'] = random.getstate()
sd['np_rng_state'] = np.random.get_state()
sd['torch_rng_state'] = torch.get_rng_state()
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
ensure_directory_exists(checkpoint_name)
torch.save(sd, checkpoint_name)
print(' successfully saved {}'.format(checkpoint_name))
# Wait so everyone is done (necessary)
torch.distributed.barrier()
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler, args):
"""Load a model checkpoint."""
if isinstance(model, torchDDP):
model = model.module
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(args.load)
if not os.path.isfile(tracker_filename):
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return 0
iteration = 0
release = False
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == 'release'
if not release:
print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
tracker_filename))
exit()
assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)
# Checkpoint.
checkpoint_name = get_checkpoint_name(args.load, iteration, release)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
# Load the checkpoint.
sd = torch.load(checkpoint_name, map_location='cpu')
# Iterations.
if args.finetune or release:
iteration = 0
else:
try:
iteration = sd['iteration']
except KeyError:
try: # Backward compatible with older checkpoints
iteration = sd['total_iters']
except KeyError:
                print_rank_0('A metadata file exists but unable to load iteration '
                             'from checkpoint {}, exiting'.format(checkpoint_name))
exit()
# Model.
try:
model.load_state_dict(sd['model'])
except KeyError:
print_rank_0('A metadata file exists but unable to load model '
'from checkpoint {}, exiting'.format(checkpoint_name))
exit()
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(sd['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(sd['lr_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer '
'state.'.format(checkpoint_name))
exit()
# rng states.
if not release and not args.finetune and not args.no_load_rng:
try:
random.setstate(sd['random_rng_state'])
np.random.set_state(sd['np_rng_state'])
torch.set_rng_state(sd['torch_rng_state'])
torch.cuda.set_rng_state(sd['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}, exiting. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng '
                         'state.'.format(checkpoint_name))
exit()
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
return iteration
def load_weights(src, dst, dst2src=False):
"""
    Loads weights from src to dst via in-place copy.
    src is a huggingface gpt2 model, while dst is one of our models.
    dst2src=True loads parameters from our models into huggingface's.
    Note: dst2src=True is still untested.
"""
conv_layer = 'Conv1D' in str(type(src))
for n, p in src.named_parameters():
if dst2src:
data = dst._parameters[n].data
load = p.data
else:
data = p.data
load = dst._parameters[n].data
if conv_layer and 'weight' in n:
data = data.t().contiguous()
load.copy_(data)
# dst._parameters[n].data.copy_(data)
def load_mlp(our, oai, dst2src=False):
load_weights(oai.c_fc, our.dense_h_to_4h, dst2src)
load_weights(oai.c_proj, our.dense_4h_to_h, dst2src)
def load_attention(our, oai, dst2src=False):
load_weights(oai.c_attn, our.query_key_value, dst2src)
load_weights(oai.c_proj, our.dense, dst2src)
def load_transformer_layer(our, oai, dst2src=False):
load_weights(oai.ln_1, our.input_layernorm, dst2src)
load_weights(oai.ln_2, our.post_attention_layernorm, dst2src)
load_mlp(our.mlp, oai.mlp, dst2src)
load_attention(our.attention, oai.attn, dst2src)
def move_weights(our, oai, dst2src=False):
"""
    Loads weights from `oai` to `our` via in-place copy.
    `oai` is a huggingface gpt2 model, while `our` is one of our models.
    dst2src=True loads parameters from our models into huggingface's.
    Note: dst2src=True is still untested.
"""
# while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)):
# our=our.module
transformer_model = oai.transformer
load_weights(transformer_model.ln_f, our.transformer.final_layernorm, dst2src)
load_weights(transformer_model.wte, our.word_embeddings, dst2src)
load_weights(transformer_model.wpe, our.position_embeddings, dst2src)
for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h):
load_transformer_layer(our_layer, oai_layer, dst2src)
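# Minimal usage sketch (assumes a huggingface GPT2 model `hf_model` and one of
# our GPT2 models `megatron_model` with matching dimensions):
#   move_weights(megatron_model, hf_model)                 # huggingface -> ours
#   move_weights(megatron_model, hf_model, dst2src=True)   # ours -> huggingface (untested)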