Commit 03d09af0 authored by Jared Casper

Adding checkpoint_util and associated loader and saver.

parent 89e8d27e
......@@ -404,29 +404,15 @@ python tools/create_doc_index.py \
We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning.
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently only tensor model parallelism is supported on input and pipeline model parallelism on the output. This example reads in a model with 2-way tensor model parallelism and writes out a model with 2-way pipeline model parallelism.
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on fewer GPUs in downstream tasks. The following script accomplishes this. This example reads in a GPT model with 4-way tensor and 4-way pipeline model parallelism and writes out a model with 2-way tensor and 2-way pipeline model parallelism.
<pre>
TENSOR_MODEL_PARALLEL_SIZE=2
TARGET_PIPELINE_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--pipeline-model-parallel-size 1 \
--target-pipeline-model-parallel-size $TARGET_PIPELINE_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 512 \
--max-position-embeddings 512 \
--load $CHECKPOINT_PATH
--save $CHECKPOINT_PATH/merged
python tools/checkpoint_util.py \
--model-type GPT \
--load-dir checkpoints/gpt3_tp4_pp4 \
--save-dir checkpoints/gpt3_tp2_pp2 \
--target-tensor-parallel-size 2 \
--target-pipeline-parallel-size 2
</pre>
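The conversion is handled by pluggable loader and saver modules (selected with `--loader` and `--saver`, both defaulting to `megatron`), so support for other checkpoint layouts can be added by supplying additional plugins on the Python path.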
......
......@@ -21,7 +21,7 @@ import os
import torch
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
ignore_unknown_args=False, validate=True):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
allow_abbrev=False)
......@@ -52,6 +52,11 @@ def parse_args(extra_args_provider=None, defaults={},
else:
args = parser.parse_args()
if validate:
return validate_args(args, defaults)
return args
def validate_args(args, defaults={}):
# Distributed args.
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
......@@ -547,6 +552,11 @@ def _add_checkpointing_args(parser):
help='Load model for finetuning. Do not load optimizer '
'or rng state from checkpoint and set iteration to 0. '
'Assumed when loading a release checkpoint.')
group.add_argument('--no-initialization', action='store_false',
help='Do not perform initialization when building model, '
'can reduce startup time when definitely loading from a '
'checkpoint',
dest='perform_initialization')
return parser
......
......@@ -80,27 +80,56 @@ def ensure_directory_exists(filename):
os.makedirs(dirname)
def get_checkpoint_name(checkpoints_path, iteration,
release=False):
def get_checkpoint_name(checkpoints_path, iteration, release=False,
pipeline_parallel_size=None, tensor_rank=None, pipeline_rank=None):
"""A unified checkpoint name."""
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
# Use both the tensor and pipeline MP rank.
if mpu.get_pipeline_model_parallel_world_size() == 1:
if pipeline_parallel_size is None:
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
if tensor_rank is None:
tensor_rank = mpu.get_tensor_model_parallel_rank()
if pipeline_rank is None:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
if pipeline_parallel_size == 1:
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_tensor_model_parallel_rank()),
f'mp_rank_{tensor_rank:02d}',
'model_optim_rng.pt')
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}_{:03d}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank()),
f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}',
'model_optim_rng.pt')
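# For illustration: with iteration=1000, tensor_rank=1 and pipeline_parallel_size == 1
# this returns <checkpoints_path>/iter_0001000/mp_rank_01/model_optim_rng.pt, while
# with pipeline parallelism and pipeline_rank=2 it returns
# <checkpoints_path>/iter_0001000/mp_rank_01_002/model_optim_rng.pt.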
def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
"""Finds the checkpoint for rank 0 without knowing if we are using
pipeline parallelism or not.
Since the checkpoint naming scheme changes when pipeline parallelism
is present, we need to check both naming schemes here.
"""
# Look for checkpoint with no pipelining
filename = get_checkpoint_name(checkpoints_path, iteration, release,
pipeline_parallel_size=1,
tensor_rank=0, pipeline_rank=0)
if os.path.isfile(filename):
return filename
# Look for checkpoint with pipelining
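# (Any pipeline_parallel_size other than 1 selects the pipelined naming scheme;
# the exact value is not otherwise used by get_checkpoint_name.)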
filename = get_checkpoint_name(checkpoints_path, iteration, release,
pipeline_parallel_size=2,
tensor_rank=0, pipeline_rank=0)
if os.path.isfile(filename):
return filename
return None
def get_checkpoint_tracker_filename(checkpoints_path):
"""Tracker file rescords the latest chckpoint during
training to restart from."""
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
......@@ -125,18 +154,24 @@ def read_metadata(tracker_filename):
tracker_filename)
# Get the max iteration retrieved across the ranks.
iters_cuda = torch.cuda.LongTensor([iteration])
torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
max_iter = iters_cuda[0].item()
# We should now have all the same iteration.
# If not, print a warning and choose the maximum
# iteration across all ranks.
if iteration != max_iter:
print('WARNING: on rank {} found iteration {} in the '
'metadata while max iteration across the ranks '
'is {}, replacing it with max iteration.'.format(
rank, iteration, max_iter), flush=True)
if torch.distributed.is_initialized():
iters_cuda = torch.cuda.LongTensor([iteration])
torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
max_iter = iters_cuda[0].item()
# We should now have all the same iteration.
# If not, print a warning and choose the maximum
# iteration across all ranks.
if iteration != max_iter:
print('WARNING: on rank {} found iteration {} in the '
'metadata while max iteration across the ranks '
'is {}, replacing it with max iteration.'.format(
rank, iteration, max_iter), flush=True)
else:
# When loading a checkpoint outside of training (for example,
# when editing it), we might not have torch distributed
# initialized; in this case, just assume we have the latest iteration.
max_iter = iteration
return max_iter, release
......@@ -270,35 +305,38 @@ def fix_query_key_value_ordering(model, checkpoint_version):
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
parameters and buffers in model.
def _load_base_checkpoint(load_dir, rank0=False):
""" Load the base state_dict from the given directory
If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
"""
args = get_args()
load_dir = getattr(args, load_arg)
model = utils.unwrap_model(model)
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir)
# If no tracker file, return iteration zero.
# If no tracker file, return nothing
if not os.path.isfile(tracker_filename):
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return 0
if not rank0:
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return None, False
# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration, release = read_metadata(tracker_filename)
# Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
if rank0:
checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
else:
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if release:
print_rank_0(f' loading release checkpoint from {load_dir}')
else:
print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')
# Load the checkpoint.
try:
......@@ -306,7 +344,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
if not rank0:
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
......@@ -319,6 +358,79 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
print_rank_0(e)
sys.exit()
return state_dict, release
def load_args_from_checkpoint(args, load_arg='load'):
"""Set any arguments that are not currently set from the checkpoint
specified in the arguments.
Returns the same args Namespace with the new values added/updated.
If no checkpoint is specified in args, or if the checkpoint is
there but invalid, the arguments will not be modified.
"""
load_dir = getattr(args, load_arg)
if load_dir is None:
return args
state_dict, release = _load_base_checkpoint(load_dir, True)
if not state_dict:
return args
if 'args' not in state_dict:
return args
checkpoint_args = state_dict['args']
checkpoint_version = state_dict.get('checkpoint_version', 0)
args.iteration = state_dict['iteration']
def _set_arg(arg_name, old_arg_name=None, force=False):
if not force and getattr(args, arg_name, None) is not None:
return
if old_arg_name is not None:
checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
else:
checkpoint_value = getattr(checkpoint_args, arg_name, None)
if checkpoint_value is not None:
print(f"Setting {arg_name} to {checkpoint_value}")
setattr(args, arg_name, checkpoint_value)
_set_arg('num_layers')
_set_arg('hidden_size')
_set_arg('ffn_hidden_size')
_set_arg('seq_length')
_set_arg('num_attention_heads')
_set_arg('kv_channels')
_set_arg('max_position_embeddings')
_set_arg('tokenizer_type')
_set_arg('padded_vocab_size')
if checkpoint_version < 3.0:
_set_arg('tensor_model_parallel_size',
'model_parallel_size')
else:
_set_arg('tensor_model_parallel_size', force=True)
_set_arg('pipeline_model_parallel_size', force=True)
_set_arg('num_layers_per_virtual_pipeline_stage')
return args
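# Rough usage sketch for load_args_from_checkpoint outside of training,
# mirroring the new checkpoint loader added in this commit: parse arguments
# without validation, fill in missing values from the checkpoint, then validate.
#   margs = parse_args(validate=False)
#   margs = load_args_from_checkpoint(margs)
#   margs = validate_args(margs)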
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
parameters and buffers in model.
"""
args = get_args()
load_dir = getattr(args, load_arg)
model = utils.unwrap_model(model)
state_dict, release = _load_base_checkpoint(load_dir, False)
# set checkpoint version
set_checkpoint_version(state_dict.get('checkpoint_version', 0))
......@@ -410,7 +522,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
def load_biencoder_checkpoint(model, only_query_model=False,
only_context_model=False, custom_load_path=None):
"""
selectively load retrieval models for indexing/retrieving
selectively load retrieval models for indexing/retrieving
from saved checkpoints
"""
......@@ -445,4 +557,3 @@ def load_biencoder_checkpoint(model, only_query_model=False,
print(' successfully loaded {}'.format(checkpoint_name))
return model
......@@ -77,11 +77,15 @@ def get_timers():
def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
ignore_unknown_args=False, parse_args=True):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
if parse_args:
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
else:
_ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
args = get_args()
_build_num_microbatches_calculator(args)
if args.vocab_file:
_ = _build_tokenizer(args)
......@@ -89,6 +93,9 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
_set_adlr_autoresume(args)
_set_timers()
def set_args(args):
global _GLOBAL_ARGS
_GLOBAL_ARGS = args
def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
......@@ -97,7 +104,8 @@ def _parse_args(extra_args_provider=None, defaults={},
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults,
ignore_unknown_args=ignore_unknown_args)
ignore_unknown_args=ignore_unknown_args,
validate=True)
return _GLOBAL_ARGS
......
......@@ -141,7 +141,8 @@ class Embedding(MegatronModule):
max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
if args.perform_initialization:
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
......@@ -152,7 +153,8 @@ class Embedding(MegatronModule):
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
if args.perform_initialization:
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
......
......@@ -47,7 +47,8 @@ def attention_mask_func(attention_scores, attention_mask):
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
if get_args().perform_initialization:
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
......
......@@ -165,15 +165,17 @@ class VocabParallelEmbedding(torch.nn.Module):
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
dtype=args.params_dtype))
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method)
if args.perform_initialization:
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method)
else:
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
if args.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
......@@ -218,7 +220,7 @@ class ColumnParallelLinear(torch.nn.Module):
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimizations where bias
can be fused with other elementwise operations. we skip
can be fused with other elementwise operations. we skip
adding bias but instead return it.
"""
......@@ -246,16 +248,18 @@ class ColumnParallelLinear(torch.nn.Module):
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size,
dtype=args.params_dtype))
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
if args.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if args.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if bias:
if args.use_cpu_initialization:
......@@ -346,16 +350,18 @@ class RowParallelLinear(torch.nn.Module):
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition,
dtype=args.params_dtype))
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
if args.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
if args.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
if bias:
if args.use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size,
......@@ -389,4 +395,3 @@ class RowParallelLinear(torch.nn.Module):
output = output_
output_bias = self.bias
return output, output_bias
import os
import sys
import types
import torch
def add_arguments(parser):
group = parser.add_argument_group(title='Megatron loader')
group.add_argument('--megatron-path', type=str, default=None,
help='Base directory of Megatron repository')
def _load_checkpoint(queue, args):
# Search in directory above this
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
try:
from megatron.arguments import parse_args, validate_args
from megatron.global_vars import set_args, set_global_variables, rebuild_tokenizer
from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
from megatron import mpu, fused_kernels
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
queue.put("exit")
exit(1)
def get_models(count, dtype, pre_process, post_process):
if args.model_type == 'GPT':
from pretrain_gpt import model_provider
elif args.model_type == 'BERT':
from pretrain_bert import model_provider
else:
raise Exception(f'unrecognized model type: {args.model_type}')
# with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
# futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
# models = [f.result().bfloat16() for f in futures]
models = []
for rank in range(count):
mpu.initialize.set_tensor_model_parallel_rank(rank)
model_ = [model_provider(pre_process, post_process).to(dtype)]
margs.consumed_train_samples = 0
margs.consumed_valid_samples = 0
load_checkpoint(model_, None, None)
assert(len(model_) == 1)
models.append(model_[0])
return models
# We want all arguments to come from us
sys.argv = ['script.py',
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--no-initialization',
'--load', args.load_dir
]
margs = parse_args(validate=False)
margs = load_args_from_checkpoint(margs)
def check_for_arg(arg_name):
if getattr(margs, arg_name, None) is None:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
print(f"Arguments: {margs}")
queue.put("exit")
exit(1)
check_for_arg('tensor_model_parallel_size')
check_for_arg('pipeline_model_parallel_size')
check_for_arg('num_layers')
check_for_arg('hidden_size')
check_for_arg('seq_length')
check_for_arg('num_attention_heads')
check_for_arg('max_position_embeddings')
check_for_arg('tokenizer_type')
check_for_arg('iteration')
check_for_arg('bert_binary_head')
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we have plenty of processes
os.environ["WORLD_SIZE"] = f'{margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size}'
margs = validate_args(margs)
check_for_arg('params_dtype')
set_args(margs)
if margs.num_layers_per_virtual_pipeline_stage is not None:
print("Model with an interleaved pipeline schedule are not yet supported.")
queue.put("exit")
exit(1)
set_global_variables(parse_args=False)
mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
fused_kernels.load(margs)
# short aliases
tp_size = margs.tensor_model_parallel_size
pp_size = margs.pipeline_model_parallel_size
# metadata
md = types.SimpleNamespace()
md.model_type = args.model_type
md.num_layers = margs.num_layers
md.hidden_size = margs.hidden_size
md.seq_length = margs.seq_length
md.num_attention_heads = margs.num_attention_heads
md.max_position_embeddings = margs.max_position_embeddings
md.tokenizer_type = margs.tokenizer_type
md.iteration = margs.iteration
md.params_dtype = margs.params_dtype
md.bert_binary_head = margs.bert_binary_head
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
queue.put(md)
# Get first pipe stage
mpu.initialize.set_pipeline_model_parallel_rank(0)
post_process = pp_size == 1
models = get_models(tp_size, md.params_dtype, True, post_process)
# Send embeddings
word_embed = []
for tp_rank in range(tp_size):
if tp_rank == 0:
print("Sending position embeddings")
queue.put(models[tp_rank].language_model.embedding.position_embeddings.weight.data)
word_embed.append(models[tp_rank].language_model.embedding.word_embeddings.weight.data)
full_word_embed = torch.cat(word_embed, dim=0)
print("Sending word embeddings")
queue.put(full_word_embed)
total_layer_num = 0
for pp_rank in range(pp_size):
if pp_rank > 0:
mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == pp_size - 1
models = get_models(tp_size, md.params_dtype, False, post_process)
for layer_num in range(len(models[0].language_model.encoder.layers)):
qkv_weight = []
qkv_bias = []
dense_weight = []
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
# Get non-parallel tensors from tp_rank 0
layer = models[0].language_model.encoder.layers[layer_num]
input_layernorm_weight = layer.input_layernorm.weight.data
input_layernorm_bias = layer.input_layernorm.bias.data
dense_bias = layer.self_attention.dense.bias.data
post_layernorm_weight = layer.post_attention_layernorm.weight.data
post_layernorm_bias = layer.post_attention_layernorm.bias.data
mlp_l1_bias = layer.mlp.dense_4h_to_h.bias.data
# Grab all parallel tensors for this layer
for tp_rank, model in enumerate(models):
layer = model.language_model.encoder.layers[layer_num]
qkv_weight.append(layer.self_attention.query_key_value.weight.data)
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
# send everything in order while concatenating them
print(f"Sending layer {layer_num} of pipeline rank {pp_rank} (total layer {total_layer_num})")
queue.put(input_layernorm_weight)
queue.put(input_layernorm_bias)
queue.put(torch.cat(qkv_weight, dim=0))
queue.put(torch.cat(qkv_bias, dim=0))
queue.put(torch.cat(dense_weight, dim=1))
queue.put(dense_bias)
queue.put(post_layernorm_weight)
queue.put(post_layernorm_bias)
queue.put(torch.cat(mlp_l0_weight, dim=0))
queue.put(torch.cat(mlp_l0_bias, dim=0))
queue.put(torch.cat(mlp_l1_weight, dim=1))
queue.put(mlp_l1_bias)
total_layer_num = total_layer_num + 1
# Send final layernorm from tp_rank 0
print("Sending final layernorm")
queue.put(models[0].language_model.encoder.final_layernorm.weight.data)
queue.put(models[0].language_model.encoder.final_layernorm.bias.data)
# Send BERT lm head and binary head if it exists
if md.model_type == 'BERT':
print("Sending LM Pooler")
queue.put("pooler")
queue.put(models[0].language_model.pooler.dense.weight.data)
queue.put(models[0].language_model.pooler.dense.bias.data)
print("Sending BERT LM head")
queue.put("lm head")
queue.put(models[0].lm_head.dense.weight.data)
queue.put(models[0].lm_head.dense.bias.data)
queue.put(models[0].lm_head.layernorm.weight.data)
queue.put(models[0].lm_head.layernorm.bias.data)
if md.bert_binary_head:
print("Sending BERT Binary head")
queue.put("binary head")
queue.put(models[0].binary_head.weight.data)
queue.put(models[0].binary_head.bias.data)
queue.put("done")
def load_checkpoint(queue, args):
try:
_load_checkpoint(queue, args)
except:
queue.put("exit")
raise
import argparse
import concurrent.futures
import os
import sys
import torch
def add_arguments(parser):
group = parser.add_argument_group(title='Megatron saver')
group.add_argument('--megatron-path', type=str, default=None,
help='Base directory of Megatron repository')
group.add_argument('--target-tensor-parallel-size', type=int,
help='Target tensor model parallel size, defaults to the tensor parallel size '
'in the input checkpoint if provided by the loader, otherwise to 1')
group.add_argument('--target-pipeline-parallel-size', type=int,
help='Target pipeline model parallel size, defaults to the pipeline parallel size '
'in the input checkpoint if provided by the loader, otherwise to 1')
def save_checkpoint(queue, args):
# Search in directory above this
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
try:
from megatron.checkpointing import save_checkpoint
from megatron.global_vars import set_global_variables, get_args
from megatron import mpu, fused_kernels
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
exit(1)
def queue_get():
val = queue.get()
if val == "exit":
print("Loader exited, exiting saver")
exit(1)
return val
md = queue_get()
def get_models(count, dtype, pre_process, post_process):
if md.model_type == 'GPT':
from pretrain_gpt import model_provider
elif md.model_type == 'BERT':
from pretrain_bert import model_provider
else:
raise Exception(f'unrecognized model type: {md.model_type}')
# with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
# futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
# models = [f.result().bfloat16() for f in futures]
models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
return models
if args.target_tensor_parallel_size is None:
if hasattr(md, 'previous_tensor_parallel_size'):
args.target_tensor_parallel_size = md.previous_tensor_parallel_size
else:
print("loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
"Default to 1.")
args.target_tensor_parallel_size = 1
if args.target_pipeline_parallel_size is None:
if hasattr(md, 'previous_pipeline_parallel_size'):
args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
else:
print("loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
"Default to 1.")
args.target_pipeline_parallel_size = 1
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we have plenty of processes
if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'
# We want all arguments to come from us
sys.argv = ['script.py',
'--num-layers', str(md.num_layers),
'--hidden-size', str(md.hidden_size),
'--seq-length', str(md.seq_length),
'--num-attention-heads', str(md.num_attention_heads),
'--max-position-embeddings', str(md.max_position_embeddings),
'--tokenizer-type', str(md.tokenizer_type),
'--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
'--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--no-initialization',
'--save-interval', '1',
'--save', args.save_dir
]
if md.params_dtype == torch.float16:
sys.argv.append('--fp16')
elif md.params_dtype == torch.bfloat16:
sys.argv.append('--bf16')
if md.model_type == 'BERT' and not md.bert_binary_head:
sys.argv.append('--bert-no-binary-head')
set_global_variables()
# margs = megatron args
margs = get_args()
# fake initializing distributed
mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_rank(0)
fused_kernels.load(margs)
# Embeddings
#-----------
pos_embed = queue_get()
full_word_embed = queue_get()
# Tell Megatron what our full size is
margs.padded_vocab_size = full_word_embed.shape[0]
if margs.padded_vocab_size % args.target_tensor_parallel_size != 0:
print("source vocab size is not evenly divisble by target tensor parallel size")
exit(1)
# Split into new tensor model parallel sizes
out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
# Make models for first pipeline stage and fill in embeddings
mpu.initialize.set_pipeline_model_parallel_rank(0)
post_process = args.target_pipeline_parallel_size == 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
for tp_rank, model in enumerate(models):
model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
# Transformer layers
#-------------------
if md.num_layers % args.target_pipeline_parallel_size != 0:
print("Source number of layers is not divisible by target pipeline parallel size")
exit(1)
layers_per_rank = md.num_layers // args.target_pipeline_parallel_size
assert layers_per_rank == len(models[0].language_model.encoder.layers)
for pp_rank in range(args.target_pipeline_parallel_size):
# For later pipeline parallel ranks, make the new models
if pp_rank > 0:
mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == args.target_pipeline_parallel_size - 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
for layer in range(layers_per_rank):
# get full tensors
input_layernorm_weight = queue_get()
input_layernorm_bias = queue_get()
full_qkv_weight = queue_get()
full_qkv_bias = queue_get()
full_dense_weight = queue_get()
dense_bias = queue_get()
post_layernorm_weight = queue_get()
post_layernorm_bias = queue_get()
full_mlp_l0_weight = queue_get()
full_mlp_l0_bias = queue_get()
full_mlp_l1_weight = queue_get()
mlp_l1_bias = queue_get()
# Split up the parallel tensors
out_qkv_weight = torch.chunk(full_qkv_weight, args.target_tensor_parallel_size, dim=0)
out_qkv_bias = torch.chunk(full_qkv_bias, args.target_tensor_parallel_size, dim=0)
out_dense_weight = torch.chunk(full_dense_weight, args.target_tensor_parallel_size, dim=1)
out_mlp_l0_weight = torch.chunk(full_mlp_l0_weight, args.target_tensor_parallel_size, dim=0)
out_mlp_l0_bias = torch.chunk(full_mlp_l0_bias, args.target_tensor_parallel_size, dim=0)
out_mlp_l1_weight = torch.chunk(full_mlp_l1_weight, args.target_tensor_parallel_size, dim=1)
# Save them to the model
for tp_rank in range(args.target_tensor_parallel_size):
l = models[tp_rank].language_model.encoder.layers[layer]
l.input_layernorm.weight.data.copy_(input_layernorm_weight)
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
l.self_attention.query_key_value.weight.data.copy_(out_qkv_weight[tp_rank])
l.self_attention.query_key_value.bias.data.copy_(out_qkv_bias[tp_rank])
l.self_attention.dense.weight.data.copy_(out_dense_weight[tp_rank])
l.self_attention.dense.bias.data.copy_(dense_bias)
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.mlp.dense_h_to_4h.weight.data.copy_(out_mlp_l0_weight[tp_rank])
l.mlp.dense_h_to_4h.bias.data.copy_(out_mlp_l0_bias[tp_rank])
l.mlp.dense_4h_to_h.weight.data.copy_(out_mlp_l1_weight[tp_rank])
l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
if post_process:
final_layernorm_weight = queue_get()
final_layernorm_bias = queue_get()
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if pp_rank != 0:
# Copy word embeddings to final pipeline rank
models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
del final_layernorm_weight
del final_layernorm_bias
name = queue_get()
if name == "pooler":
if not hasattr(models[0].language_model, 'pooler'):
print("ERROR: got a pooler, but model does not have one")
exit(1)
pooler_weight = queue_get()
pooler_bias = queue_get()
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
name = queue_get()
del pooler_weight
del pooler_bias
if name == "lm head":
if not hasattr(models[0], 'lm_head'):
print("ERROR: got an lm head, but model does not have one")
exit(1)
lm_head_dense_weight = queue_get()
lm_head_dense_bias = queue_get()
lm_head_layernorm_weight = queue_get()
lm_head_layernorm_bias = queue_get()
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
name = queue_get()
if name == "binary head":
if not hasattr(models[0], 'binary_head'):
print("ERROR: got a binary head, but model does not have one")
exit(1)
binary_head_weight = queue_get()
binary_head_bias = queue_get()
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
name = queue_get()
if name != "done":
print("ERROR: got some more data but were expecting to be done")
for tp_rank in range(args.target_tensor_parallel_size):
mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
save_checkpoint(md.iteration, [models[tp_rank]], None, None)
print("Done!")
import argparse
import importlib
import torch.multiprocessing as mp
import os
# A loader is a python file with at least two functions
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments
# A saver is similar but has save_checkpoint instead of
# load_checkpoint
# The loader and saver processes are each given a queue. The loader
# should load the checkpoint and send the weights over the queue in the
# order below, and the saver should receive them in that order and save
# the checkpoint. Note that the weights sent over the queue are the full
# model weights, nothing split.
# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.
# - Metadata Namespace with the following attributes:
# model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
# num_layers - Number of transformer layers
# hidden_size
# seq_length
# num_attention_heads
# max_position_embeddings
# tokenizer_type
# iteration
# params_dtype
# bert_binary_head - Used only if model_type is BERT
# previous_tensor_parallel_size - Optional
# previous_pipeline_parallel_size - Optional
# - Position embeddings
# - Word embeddings
# - For each transformer layer:
# - input layernorm weights
# - input layernorm bias
# - qkv weight
# - qkv bias
# - dense weight
# - dense bias
# - post attention layernorm weight
# - post attention layernorm bias
# - mlp layer 0 (h to 4h) weight
# - mlp layer 0 (h to 4h) bias
# - mlp layer 1 (4h to h) weight
# - mlp layer 1 (4h to h) bias
# - final layer norm weight
# - final layer norm bias
# - if present (i.e. for BERT):
# - "pooler"
# - LM Pooler weight
# - LM Pooler bias
# - "lm head"
# - LM head dense weight
# - LM head dense bias
# - LM head layernorm weight
# - LM head layernorm bias
# - "binary head"
# - BERT Binary head weight
# - BERT Binary head bias
# - "done"
def load_plugin(plugin_type, name):
module_name = f"checkpoint_{plugin_type}_{name}"
try:
plugin = importlib.import_module(module_name)
except ModuleNotFoundError:
module_name = name
try:
plugin = importlib.import_module(module_name)
except ModuleNotFoundError:
print(f"Unable to load {plugin_type} plugin {name}. Exiting.")
exit(1)
if not hasattr(plugin, 'add_arguments'):
print(f"{module_name} module is not a plugin. Exiting.")
exit(1)
print(f"Loaded {module_name} as the {plugin_type}.")
return plugin
def main():
import argparse
parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
allow_abbrev=False, conflict_handler='resolve')
parser.add_argument('--model-type', type=str, required=True,
choices=['GPT', 'BERT'],
help='Type of the model')
parser.add_argument('--loader', type=str, default='megatron',
help='Module name to load checkpoint, should be on python path')
parser.add_argument('--saver', type=str, default='megatron',
help='Module name to save checkpoint, should be on python path')
parser.add_argument('--load-dir', type=str, required=True,
help='Directory to load model checkpoint from')
parser.add_argument('--save-dir', type=str, required=True,
help='Directory to save model checkpoint to')
parser.add_argument('--max-queue-size', type=int, default=50,
help='Maximum number of tensors in the queue')
known_args, _ = parser.parse_known_args()
loader = load_plugin('loader', known_args.loader)
saver = load_plugin('saver', known_args.saver)
loader.add_arguments(parser)
saver.add_arguments(parser)
args = parser.parse_args()
queue = mp.Queue(maxsize=args.max_queue_size)
print("Starting saver...")
saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
saver_proc.start()
print("Starting loader...")
loader.load_checkpoint(queue, args)
print("Waiting for saver to complete...")
saver_proc.join()
if __name__ == '__main__':
main()