Commit 5f694372 authored by Jared Casper's avatar Jared Casper

Merge branch 'checkpoint_util' into 'main'

Checkpoint conversion utility

See merge request ADLR/megatron-lm!289
parents bdbb30fb 748a9a7d
......@@ -410,29 +410,15 @@ python tools/create_doc_index.py \
We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing; otherwise, the training will start again from the beginning.
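For example, a fine-tuning run might reuse the original pretraining command with the flag added (the script name, data path, and checkpoint paths below are placeholders to adapt to your setup):

<pre>
python pretrain_gpt.py \
       [same model and training arguments as the original run] \
       --finetune \
       --load $PRETRAINED_CHECKPOINT_PATH \
       --save $FINETUNED_CHECKPOINT_PATH \
       --data-path $FINETUNING_DATA_PATH
</pre>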
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently only tensor model parallelism is supported on input and pipeline model parallelism on the output. This example reads in a model with 2-way tensor model parallelism and writes out a model with 2-way pipeline model parallelism.
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on fewer GPUs in downstream tasks. The following script accomplishes this. This example reads in a GPT model with 4-way tensor and 4-way pipeline model parallelism and writes out a model with 2-way tensor and 2-way pipeline model parallelism.
<pre>
TENSOR_MODEL_PARALLEL_SIZE=2
TARGET_PIPELINE_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--model-type BERT \
--tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
--pipeline-model-parallel-size 1 \
--target-pipeline-model-parallel-size $TARGET_PIPELINE_MODEL_PARALLEL_SIZE \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file $VOCAB_FILE \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 512 \
--max-position-embeddings 512 \
--load $CHECKPOINT_PATH
--save $CHECKPOINT_PATH/merged
python tools/checkpoint_util.py \
--model-type GPT \
--load-dir checkpoints/gpt3_tp4_pp4 \
--save-dir checkpoints/gpt3_tp2_pp2 \
--target-tensor-parallel-size 2 \
        --target-pipeline-parallel-size 2
</pre>
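The same utility can convert BERT checkpoints by changing `--model-type`; for example (the checkpoint directory names are placeholders):

<pre>
python tools/checkpoint_util.py \
        --model-type BERT \
        --load-dir checkpoints/bert_tp4_pp4 \
        --save-dir checkpoints/bert_tp2_pp2 \
        --target-tensor-parallel-size 2 \
        --target-pipeline-parallel-size 2
</pre>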
......
......@@ -26,22 +26,6 @@ from .global_vars import get_timers
from .global_vars import get_global_memory_buffer
from .initialize import initialize_megatron
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
from .utils import (print_rank_0,
is_last_rank,
print_rank_last)
......@@ -20,8 +20,7 @@ import os
import torch
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
allow_abbrev=False)
......@@ -53,9 +52,13 @@ def parse_args(extra_args_provider=None, defaults={},
else:
args = parser.parse_args()
# Distributed args.
# Args from environment
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
return args
def validate_args(args, defaults={}):
# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
......@@ -679,6 +682,14 @@ def _add_checkpointing_args(parser):
help='Load model for finetuning. Do not load optimizer '
'or rng state from checkpoint and set iteration to 0. '
'Assumed when loading a release checkpoint.')
group.add_argument('--no-initialization', action='store_false',
help='Do not perform initialization when building model, '
'can reduce startup time when definitely loading from a '
'checkpoint',
dest='perform_initialization')
group.add_argument('--use-checkpoint-args', action='store_true',
help='Override any command line arguments with arguments '
'from the checkpoint')
return parser
......
......@@ -22,11 +22,12 @@ import numpy as np
import torch
from megatron import (get_args,
mpu,
print_rank_0,
update_num_microbatches,
utils)
from megatron import (mpu,
update_num_microbatches)
from .global_vars import get_args
from .utils import (unwrap_model,
print_rank_0)
_CHECKPOINT_VERSION = None
......@@ -81,25 +82,31 @@ def ensure_directory_exists(filename):
os.makedirs(dirname)
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
release=False):
"""A unified checkpoint name."""
def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release=False,
pipeline_parallel=None, tensor_rank=None, pipeline_rank=None):
"""Determine the directory name for this rank's checkpoint."""
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
# Use both the tensor and pipeline MP rank.
if pipeline_parallel is None:
pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
if tensor_rank is None:
tensor_rank = mpu.get_tensor_model_parallel_rank()
if pipeline_rank is None:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
# data parallel rank.
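# Example: iteration 5000 with tensor rank 0 and pipeline rank 3 resolves to
# the directory <checkpoints_path>/iter_0005000/mp_rank_00_003 (just
# mp_rank_00 when there is no pipeline parallelism).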
if mpu.get_pipeline_model_parallel_world_size() == 1:
if not pipeline_parallel:
common_path = os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_tensor_model_parallel_rank()))
f'mp_rank_{tensor_rank:02d}')
else:
common_path = os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}_{:03d}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank()))
f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
if use_distributed_optimizer:
model_name = os.path.join(common_path, "model_rng.pt")
......@@ -110,8 +117,34 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
return model_name, optim_name
def find_checkpoint_rank_0(checkpoints_path, iteration, use_distributed_optimizer, release=False):
"""Finds the checkpoint for rank 0 without knowing if we are using
pipeline parallelism or not.
Since the checkpoint naming scheme changes if pipeline parallelism
is present, we need to look for both naming schemes if we don't
know if the checkpoint has pipeline parallelism.
"""
# Look for checkpoint with no pipelining
filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
pipeline_parallel=False,
tensor_rank=0, pipeline_rank=0)
if os.path.isfile(filenames[0]):
return filenames
# Look for checkpoint with pipelining
filenames = get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer, release,
pipeline_parallel=True,
tensor_rank=0, pipeline_rank=0)
if os.path.isfile(filenames[0]):
return filenames
return None, None
def get_checkpoint_tracker_filename(checkpoints_path):
"""Tracker file rescords the latest chckpoint during
training to restart from."""
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
......@@ -136,18 +169,24 @@ def read_metadata(tracker_filename):
tracker_filename)
# Get the max iteration retrieved across the ranks.
iters_cuda = torch.cuda.LongTensor([iteration])
torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
max_iter = iters_cuda[0].item()
# We should now have all the same iteration.
# If not, print a warning and chose the maximum
# iteration across all ranks.
if iteration != max_iter:
print('WARNING: on rank {} found iteration {} in the '
'metadata while max iteration across the ranks '
'is {}, replacing it with max iteration.'.format(
rank, iteration, max_iter), flush=True)
if torch.distributed.is_initialized():
iters_cuda = torch.cuda.LongTensor([iteration])
torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
max_iter = iters_cuda[0].item()
# We should now have all the same iteration.
# If not, print a warning and choose the maximum
# iteration across all ranks.
if iteration != max_iter:
print('WARNING: on rank {} found iteration {} in the '
'metadata while max iteration across the ranks '
'is {}, replacing it with max iteration.'.format(
rank, iteration, max_iter), flush=True)
else:
# When loading a checkpoint outside of training (for example,
# when editing it), we might not have torch distributed
# initialized, in this case, just assume we have the latest
max_iter = iteration
return max_iter, release
......@@ -182,7 +221,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
args = get_args()
# Only rank zero of the data parallel writes to the disk.
model = utils.unwrap_model(model)
model = unwrap_model(model)
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
......@@ -331,50 +370,55 @@ def fix_query_key_value_ordering(model, checkpoint_version):
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
parameters and buffers in model.
def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
""" Load the base state_dict from the given directory
If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
"""
args = get_args()
load_dir = getattr(args, load_arg)
model = utils.unwrap_model(model)
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir)
# If no tracker file, return iretation zero.
# If no tracker file, return nothing
if not os.path.isfile(tracker_filename):
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return 0
if not rank0:
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return None, None, False
# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration, release = read_metadata(tracker_filename)
# Checkpoint.
model_checkpoint_name, optim_checkpoint_name = \
get_checkpoint_names(load_dir, iteration,
args.use_distributed_optimizer,
release)
print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
if rank0:
checkpoint_names = find_checkpoint_rank_0(load_dir, iteration, use_distributed_optimizer,
release)
else:
checkpoint_names = get_checkpoint_names(load_dir, iteration, use_distributed_optimizer,
release)
if release:
print_rank_0(f' loading release checkpoint from {load_dir}')
else:
print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')
model_checkpoint_name, optim_checkpoint_name = checkpoint_names
# Load the checkpoint.
try:
model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
if args.use_distributed_optimizer:
if use_distributed_optimizer:
optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
else:
optim_state_dict = model_state_dict
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
if not rank0:
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
......@@ -388,7 +432,99 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
print_rank_0(e)
sys.exit()
# Set checkpoint version.
return model_state_dict, optim_state_dict, release
def load_args_from_checkpoint(args, load_arg='load'):
"""Set required arguments from the checkpoint specified in the
arguments.
Arguments that are currently None are filled in from the checkpoint, while
arguments that already have a value are left unchanged (for recent checkpoint
versions, the model parallel sizes are always taken from the checkpoint).
Returns the same args NameSpace with the new values added/updated.
If no checkpoint is specified in args, or if the specified checkpoint is
invalid, the arguments are not modified.
"""
load_dir = getattr(args, load_arg)
if load_dir is None:
print_rank_0('No load directory specified, using provided arguments.')
return args
model_state_dict, optim_state_dict, release = \
_load_base_checkpoint(load_dir,
use_distributed_optimizer=args.use_distributed_optimizer,
rank0=True)
# For args we only care about model state dict
state_dict = model_state_dict
if not state_dict:
print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
return args
if 'args' not in state_dict:
print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
return args
checkpoint_args = state_dict['args']
checkpoint_version = state_dict.get('checkpoint_version', 0)
args.iteration = state_dict['iteration']
def _set_arg(arg_name, old_arg_name=None, force=False):
if not force and getattr(args, arg_name, None) is not None:
return
if old_arg_name is not None:
checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
else:
checkpoint_value = getattr(checkpoint_args, arg_name, None)
if checkpoint_value is not None:
print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
setattr(args, arg_name, checkpoint_value)
_set_arg('num_layers')
_set_arg('hidden_size')
_set_arg('ffn_hidden_size')
_set_arg('seq_length')
_set_arg('num_attention_heads')
_set_arg('kv_channels')
_set_arg('max_position_embeddings')
_set_arg('tokenizer_type')
_set_arg('padded_vocab_size')
if checkpoint_version < 3.0:
_set_arg('tensor_model_parallel_size',
'model_parallel_size')
else:
_set_arg('tensor_model_parallel_size', force=True)
_set_arg('pipeline_model_parallel_size', force=True)
_set_arg('num_layers_per_virtual_pipeline_stage')
return args
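# Typical usage, as in initialize_megatron and the checkpoint conversion tools:
#   args = parse_args(...)
#   load_args_from_checkpoint(args)  # fill in model arguments saved in the checkpoint
#   validate_args(args)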
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
parameters and buffers in model.
"""
args = get_args()
load_dir = getattr(args, load_arg)
model = unwrap_model(model)
model_state_dict, optim_state_dict, release = \
_load_base_checkpoint(load_dir,
use_distributed_optimizer=args.use_distributed_optimizer,
rank0=False)
if model_state_dict is None:
return 0
# set checkpoint version
set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
# Set iteration.
......@@ -499,13 +635,13 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
def load_biencoder_checkpoint(model, only_query_model=False,
only_context_model=False, custom_load_path=None):
"""
selectively load retrieval models for indexing/retrieving
selectively load retrieval models for indexing/retrieving
from saved checkpoints
"""
args = get_args()
model = utils.unwrap_model(model)
model = unwrap_model(model)
load_path = custom_load_path if custom_load_path is not None else args.load
......@@ -515,7 +651,8 @@ def load_biencoder_checkpoint(model, only_query_model=False,
checkpoint_name, _ = get_checkpoint_names(load_path, iteration,
args.use_distributed_optimizer,
False)
release=False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
......@@ -536,4 +673,3 @@ def load_biencoder_checkpoint(model, only_query_model=False,
print(' successfully loaded {}'.format(checkpoint_name))
return model
......@@ -24,7 +24,6 @@ import torch
from megatron import dist_signal_handler
from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
from .microbatches import build_num_microbatches_calculator
_GLOBAL_ARGS = None
......@@ -95,12 +94,15 @@ def _set_signal_handler():
_GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()
def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
def set_global_variables(args):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
assert args is not None
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
set_args(args)
_build_num_microbatches_calculator(args)
if args.vocab_file:
_ = _build_tokenizer(args)
......@@ -111,17 +113,11 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
if args.exit_signal_handler:
_set_signal_handler()
def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse entire arguments."""
def set_args(args):
global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults,
ignore_unknown_args=ignore_unknown_args)
return _GLOBAL_ARGS
_GLOBAL_ARGS = args
def _build_num_microbatches_calculator(args):
......
......@@ -28,6 +28,8 @@ from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
set_tensor_model_parallel_world_size)
......@@ -49,11 +51,18 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
# Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.'
# Parse args, build tokenizer, and set adlr-autoresume,
# Parse arguments
args = parse_args(extra_args_provider, ignore_unknown_args)
if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
assert args.load is not None, '--use-checkpoint-args requires --load argument'
load_args_from_checkpoint(args)
validate_args(args, args_defaults)
# set global args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
set_global_variables(args)
# torch.distributed initialization
def finish_mpu_init():
......
......@@ -166,7 +166,8 @@ class Embedding(MegatronModule):
max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
if args.perform_initialization:
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
......@@ -177,7 +178,8 @@ class Embedding(MegatronModule):
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
if args.perform_initialization:
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
......
......@@ -102,29 +102,32 @@ class MegatronModule(torch.nn.Module):
self.pre_process:
self.language_model.embedding.zero_parameters()
if not torch.distributed.is_initialized():
if not getattr(MegatronModule, "embedding_warning_printed", False):
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
MegatronModule.embedding_warning_printed = True
return
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
# Ensure that encoder(first stage) and decoder(split stage) position
# embeddings have the same initial parameter values
# NOTE: We don't currently support T5 with the interleaved schedule.
if mpu.is_rank_in_position_embedding_group() and \
args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_position_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
# Ensure that encoder(first stage) and decoder(split stage) position
# embeddings have the same initial parameter values
# NOTE: We don't currently support T5 with the interleaved schedule.
if mpu.is_rank_in_position_embedding_group() and \
args.pipeline_model_parallel_split_rank is not None:
# TODO: Support tokentype embedding.
self.language_model.embedding.cuda()
position_embeddings = self.language_model.embedding.position_embeddings
torch.distributed.all_reduce(position_embeddings.weight.data,
group=mpu.get_position_embedding_group())
def conversion_helper(val, conversion):
......
......@@ -47,7 +47,8 @@ def attention_mask_func(attention_scores, attention_mask):
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
if get_args().perform_initialization:
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
......
......@@ -167,15 +167,17 @@ class VocabParallelEmbedding(torch.nn.Module):
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
dtype=args.params_dtype))
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method)
if args.perform_initialization:
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method)
else:
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
if args.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
......@@ -330,7 +332,7 @@ class ColumnParallelLinear(torch.nn.Module):
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
can be fused with other elementwise operations. we skip
adding bias but instead return it.
"""
......@@ -358,16 +360,18 @@ class ColumnParallelLinear(torch.nn.Module):
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size,
dtype=args.params_dtype))
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
if args.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if args.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if bias:
if args.use_cpu_initialization:
......@@ -471,16 +475,18 @@ class RowParallelLinear(torch.nn.Module):
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition,
dtype=args.params_dtype))
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
if args.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
if args.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
if bias:
if args.use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size,
......@@ -524,4 +530,3 @@ class RowParallelLinear(torch.nn.Module):
output = output_
output_bias = self.bias
return output, output_bias
......@@ -24,7 +24,6 @@ from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
......@@ -204,3 +203,22 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
import json
import os
import sys
import types
import torch
def add_arguments(parser):
group = parser.add_argument_group(title='Megatron loader')
group.add_argument('--true-vocab-size', type=int, default=None,
help='original size of vocab, if specified will trim padding from embedding table.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file. If specified will use this to get vocab size and '
'trim padding from the embedding table.')
group.add_argument('--megatron-path', type=str, default=None,
help='Base directory of Megatron repository')
def _load_checkpoint(queue, args):
# Search in directory above this
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
try:
from megatron.arguments import parse_args, validate_args
from megatron.global_vars import set_args, set_global_variables
from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
from megatron.model import ModelType, module
from megatron import mpu, fused_kernels
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
queue.put("exit")
exit(1)
# We want all arguments to come from us
sys.argv = ['script.py',
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--no-initialization',
'--load', args.load_dir
]
margs = parse_args()
margs = load_args_from_checkpoint(margs)
# Argument validation does sanity checks on the world size, but we don't
# care here, so trick it into thinking we have plenty of processes.
margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
margs = validate_args(margs)
def check_for_arg(arg_name):
if getattr(margs, arg_name, None) is None:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
print(f"Arguments: {margs}")
queue.put("exit")
exit(1)
check_for_arg('tensor_model_parallel_size')
check_for_arg('pipeline_model_parallel_size')
check_for_arg('num_layers')
check_for_arg('hidden_size')
check_for_arg('seq_length')
check_for_arg('num_attention_heads')
check_for_arg('max_position_embeddings')
check_for_arg('tokenizer_type')
check_for_arg('iteration')
check_for_arg('bert_binary_head')
check_for_arg('params_dtype')
# Determine how to make our models
if args.model_type == 'GPT':
from pretrain_gpt import model_provider
margs.model_type = ModelType.encoder_or_decoder
elif args.model_type == 'BERT':
from pretrain_bert import model_provider
margs.model_type = ModelType.encoder_or_decoder
else:
raise Exception(f'unrecognized model type: {args.model_type}')
# suppress warning about torch.distributed not being initialized
module.MegatronModule.embedding_warning_printed = True
consumed_train_samples = None
consumed_valid_samples = None
def get_models(count, dtype, pre_process, post_process):
nonlocal consumed_train_samples
nonlocal consumed_valid_samples
models = []
for rank in range(count):
mpu.initialize.set_tensor_model_parallel_rank(rank)
model_ = [model_provider(pre_process, post_process).to(dtype)]
margs.consumed_train_samples = 0
margs.consumed_valid_samples = 0
load_checkpoint(model_, None, None)
assert(len(model_) == 1)
model_ = model_[0]
if consumed_train_samples is not None:
assert(margs.consumed_train_samples == consumed_train_samples)
else:
consumed_train_samples = margs.consumed_train_samples
if consumed_valid_samples is not None:
assert(margs.consumed_valid_samples == consumed_valid_samples)
else:
consumed_valid_samples = margs.consumed_valid_samples
models.append(model_)
return models
if margs.num_layers_per_virtual_pipeline_stage is not None:
print("Model with an interleaved pipeline schedule are not yet supported.")
queue.put("exit")
exit(1)
set_global_variables(margs)
mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
fused_kernels.load(margs)
# Get true (non-padded) vocab size
if args.true_vocab_size is not None:
true_vocab_size = args.true_vocab_size
elif args.vocab_file is not None:
vocab = json.load(open(args.vocab_file))
true_vocab_size = len(vocab)
if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size:
print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
queue.put("exit")
exit(1)
else:
true_vocab_size = None
# short aliases
tp_size = margs.tensor_model_parallel_size
pp_size = margs.pipeline_model_parallel_size
# metadata
md = types.SimpleNamespace()
md.model_type = args.model_type
md.num_layers = margs.num_layers
md.hidden_size = margs.hidden_size
md.seq_length = margs.seq_length
md.num_attention_heads = margs.num_attention_heads
md.max_position_embeddings = margs.max_position_embeddings
md.tokenizer_type = margs.tokenizer_type
md.iteration = margs.iteration
md.params_dtype = margs.params_dtype
md.bert_binary_head = margs.bert_binary_head
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
md.true_vocab_size = true_vocab_size
md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
# Get first pipe stage
mpu.initialize.set_pipeline_model_parallel_rank(0)
post_process = pp_size == 1
models = get_models(tp_size, md.params_dtype, True, post_process)
md.consumed_train_samples = consumed_train_samples
md.consumed_valid_samples = consumed_valid_samples
queue.put(md)
def queue_put(name, msg):
print(f"sending {name}")
msg["name"] = name
queue.put(msg)
# Send embeddings
message = {
"position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
"word embeddings": torch.cat(
[models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
dim = 0)
}
queue_put("embeddings", message)
total_layer_num = 0
for pp_rank in range(pp_size):
if pp_rank > 0:
mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == pp_size - 1
models = get_models(tp_size, md.params_dtype, False, post_process)
for layer_num in range(len(models[0].language_model.encoder.layers)):
message = {}
# Get non-parallel tensors from tp_rank 0
layer = models[0].language_model.encoder.layers[layer_num]
message["input layernorm weight"] = layer.input_layernorm.weight.data
message["input layernorm bias"] = layer.input_layernorm.bias.data
message["dense bias"] = layer.self_attention.dense.bias.data
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
# Grab all parallel tensors for this layer
qkv_weight = []
qkv_bias = []
dense_weight = []
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
for tp_rank, model in enumerate(models):
layer = model.language_model.encoder.layers[layer_num]
qkv_weight.append(layer.self_attention.query_key_value.weight.data)
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
# concat them
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
queue_put(f"transformer layer {total_layer_num}", message)
total_layer_num = total_layer_num + 1
# Send final layernorm from tp_rank 0
message = {
"weight": models[0].language_model.encoder.final_layernorm.weight.data,
"bias": models[0].language_model.encoder.final_layernorm.bias.data
}
queue_put("final layernorm", message)
# Send BERT lm head and binary head if it exists
if md.model_type == 'BERT':
print("Sending LM Pooler")
message = {
"weight": models[0].language_model.pooler.dense.weight.data,
"bias": models[0].language_model.pooler.dense.bias.data
}
queue_put("pooler", message)
message = {
"dense weight": models[0].lm_head.dense.weight.data,
"dense bias": models[0].lm_head.dense.bias.data,
"layernorm weight": models[0].lm_head.layernorm.weight.data,
"layernorm bias": models[0].lm_head.layernorm.bias.data
}
queue_put("lm head", message)
if md.bert_binary_head:
print("Sending BERT Binary head")
queue.put("binary head")
message = {
"weight": models[0].binary_head.weight.data,
"bias": models[0].binary_head.bias.data
}
queue_put("binary head", message)
queue.put("done")
def load_checkpoint(queue, args):
try:
_load_checkpoint(queue, args)
except:
queue.put("exit")
raise
import argparse
from collections.abc import Mapping
import concurrent.futures
import os
import sys
import torch
def add_arguments(parser):
group = parser.add_argument_group(title='Megatron saver')
group.add_argument('--megatron-path', type=str, default=None,
help='Base directory of Megatron repository')
group.add_argument('--target-tensor-parallel-size', type=int,
help='Target tensor model parallel size, defaults to the tensor parallel size '
'in the input checkpoint if provided by the loader, otherwise to 1')
group.add_argument('--target-pipeline-parallel-size', type=int,
help='Target pipeline model parallel size, defaults to the pipeline parallel size '
'in the input checkpoint if provided by the loader, otherwise to 1')
def save_checkpoint(queue, args):
# Search in directory above this
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
try:
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import save_checkpoint
from megatron.global_vars import set_global_variables, get_args
from megatron.model import ModelType
from megatron.tokenizer.tokenizer import _vocab_size_with_padding
from megatron import mpu, fused_kernels
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
exit(1)
def queue_get(name=None):
val = queue.get()
if val == "exit":
print("Loader exited, exiting saver")
exit(1)
if name is not None and args.checking and val["name"] != name:
val_name = val["name"]
print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
exit(1)
if name is not None:
print(f"received {name}")
return val
def check_message(msg):
if not args.checking:
return
msg_name = msg.pop("name")
if len(msg.keys()) > 0:
print(f"Unexpected values in {msg_name}:")
for key in msg.keys():
print(f" {key}")
print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
exit(1)
md = queue_get()
if args.target_tensor_parallel_size is None:
if hasattr(md, 'previous_tensor_parallel_size'):
args.target_tensor_parallel_size = md.previous_tensor_parallel_size
else:
print("loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
"Default to 1.")
args.target_tensor_parallel_size = 1
if args.target_pipeline_parallel_size is None:
if hasattr(md, 'previous_pipeline_parallel_size'):
args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
else:
print("loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
"Default to 1.")
args.target_pipeline_parallel_size = 1
# Argument validation does sanity checks on the world size, but we don't
# care here, so trick it into thinking we have plenty of processes.
if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'
# We want all arguments to come from us
sys.argv = ['script.py',
'--num-layers', str(md.num_layers),
'--hidden-size', str(md.hidden_size),
'--seq-length', str(md.seq_length),
'--num-attention-heads', str(md.num_attention_heads),
'--max-position-embeddings', str(md.max_position_embeddings),
'--tokenizer-type', str(md.tokenizer_type),
'--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
'--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--no-initialization',
'--save-interval', '1',
'--save', args.save_dir
]
if md.make_vocab_size_divisible_by is not None:
sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
if md.params_dtype == torch.float16:
sys.argv.append('--fp16')
elif md.params_dtype == torch.bfloat16:
sys.argv.append('--bf16')
if md.model_type == 'BERT' and not md.bert_binary_head:
sys.argv.append('--bert-no-binary-head')
margs = parse_args()
validate_args(margs)
set_global_variables(margs)
# margs = megatron args
margs = get_args()
if hasattr(md, 'consumed_train_samples'):
margs.consumed_train_samples = md.consumed_train_samples
margs.consumed_valid_samples = md.consumed_valid_samples
print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
f" and consumed_valid_samples to {margs.consumed_valid_samples}")
else:
print("consumed_train_samples not provided.")
# Determine how to make our models
if md.model_type == 'GPT':
from pretrain_gpt import model_provider
margs.model_type = ModelType.encoder_or_decoder
elif md.model_type == 'BERT':
from pretrain_bert import model_provider
margs.model_type = ModelType.encoder_or_decoder
else:
raise Exception(f'unrecognized model type: {args.model_type}')
def get_models(count, dtype, pre_process, post_process):
models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
return models
# fake initializing distributed
mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_rank(0)
fused_kernels.load(margs)
# Embeddings
#-----------
embeddings_msg = queue_get("embeddings")
pos_embed = embeddings_msg.pop("position embeddings")
orig_word_embed = embeddings_msg.pop("word embeddings")
check_message(embeddings_msg)
# Deal with padding
if md.true_vocab_size is not None:
# figure out what our padded vocab size is
orig_vocab_size = orig_word_embed.shape[0]
margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)
# Cut out extra padding we don't need
if orig_vocab_size > margs.padded_vocab_size:
full_word_embed = orig_word_embed[0:margs.padded_vocab_size,:]
# Expanding embedding to larger size by replicating final entry
elif orig_vocab_size < margs.padded_vocab_size:
padding_size = margs.padded_vocab_size - orig_vocab_size
full_word_embed = torch.cat((
orig_word_embed,
orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
# Same size!
else:
full_word_embed = orig_word_embed
else:
print("Original vocab size not specified, leaving embedding table as-is. "
"If you've changed the tensor parallel size this could cause problems.")
margs.padded_vocab_size = orig_word_embed.shape[0]
full_word_embed = orig_word_embed
# Split into new tensor model parallel sizes
out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
# Make models for first pipeline stage and fill in embeddings
mpu.initialize.set_pipeline_model_parallel_rank(0)
post_process = args.target_pipeline_parallel_size == 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
for tp_rank, model in enumerate(models):
print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
# Transformer layers
#-------------------
total_layer_num = 0
for pp_rank in range(args.target_pipeline_parallel_size):
# For later pipeline parallel ranks, make the new models
if pp_rank > 0:
mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == args.target_pipeline_parallel_size - 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
for layer in range(len(models[0].language_model.encoder.layers)):
msg = queue_get(f"transformer layer {total_layer_num}")
# duplicated tensors
input_layernorm_weight = msg.pop("input layernorm weight")
input_layernorm_bias = msg.pop("input layernorm bias")
dense_bias = msg.pop("dense bias")
post_layernorm_weight = msg.pop("post layernorm weight")
post_layernorm_bias = msg.pop("post layernorm bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
# Split up the parallel tensors
qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
# Save them to the model
for tp_rank in range(args.target_tensor_parallel_size):
l = models[tp_rank].language_model.encoder.layers[layer]
l.input_layernorm.weight.data.copy_(input_layernorm_weight)
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
l.self_attention.dense.bias.data.copy_(dense_bias)
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
total_layer_num = total_layer_num + 1
check_message(msg)
if post_process:
msg = queue_get("final layernorm")
final_layernorm_weight = msg.pop("weight")
final_layernorm_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if pp_rank != 0:
# Copy word embeddings to final pipeline rank
models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
del final_layernorm_weight
del final_layernorm_bias
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "pooler":
if not hasattr(models[0].language_model, 'pooler'):
print("ERROR: got a pooler, but model does not have one")
exit(1)
print("received pooler")
pooler_weight = msg.pop("weight")
pooler_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
del pooler_weight
del pooler_bias
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "lm head":
if not hasattr(models[0], 'lm_head'):
print("ERROR: got an lm head, but model does not have one")
exit(1)
print("received lm head")
lm_head_dense_weight = msg.pop("dense weight")
lm_head_dense_bias = msg.pop("dense bias")
lm_head_layernorm_weight = msg.pop("layernorm weight")
lm_head_layernorm_bias = msg.pop("layernorm bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "binary head":
if not hasattr(models[0], 'binary_head'):
print("ERROR: got a binary head, but model does not have one")
exit(1)
print("received binary head")
binary_head_weight = msg.pop("weight")
binary_head_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
check_message(msg)
msg = queue_get()
if msg != "done":
print("ERROR: got some more data but was expecting to be done")
for tp_rank in range(args.target_tensor_parallel_size):
mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
save_checkpoint(md.iteration, [models[tp_rank]], None, None)
print("Done!")
import argparse
import importlib
import torch.multiprocessing as mp
import os
import sys
# A loader is a python file with at least two functions
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments
# A saver is similar but has save_checkpoint instead of
# load_checkpoint
# The loader and saver process are each given a queue, the loader
# should load the checkpoint and send the weights in messages in the
# following order, the saver should receive them in this order and
# save the checkpoints. A message consists of a python dictionary with
# a "name" for error checking and an entry for each tensor as
# indicated below. Note that the weights sent over the queue are the
# full model weights, not split across parallel ranks.
# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.
# - Metadata Namespace with the following attributes:
# model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
# num_layers - Number of transformer layers
# hidden_size
# seq_length
# num_attention_heads
# max_position_embeddings
# tokenizer_type
# iteration
# params_dtype
# bert_binary_head - Used only if model_type is BERT
# previous_tensor_parallel_size - Optional
# previous_pipeline_parallel_size - Optional
# true_vocab_size
# make_vocab_size_divisible_by
# consumed_train_samples
# consumed_valid_samples
# messages
# {
# "name": "embeddings"
# "position embeddings"
# "word embeddings"
# }
# (for each transformer layer):
# {
# "name": "transformer layer N"
# "input layernorm weight"
# "input layernorm bias"
# "qkv weight"
# "qkv bias"
# "dense weight"
# "dense bias"
# "post layernorm weight"
# "post layernorm bias"
# "mlp l0 weight"
# "mlp l0 bias"
# "mlp l1 weight"
# "mlp l1 bias"
# }
# {
# "name": "final layer norm"
# "weight"
# "bias"
# }
# if present (i.e. for BERT):
# {
# "name": "pooler"
# "weight"
# "bias"
# }
# {
# "name": "lm head"
# "dense weight"
# "dense bias"
# "layernorm weight"
# "layernorm bias"
# }
# {
# "name": "binary head"
# "weight"
# "bias"
# }
# - "done"
def load_plugin(plugin_type, name):
module_name = f"checkpoint_{plugin_type}_{name}"
try:
plugin = importlib.import_module(module_name)
except ModuleNotFoundError:
module_name = name
try:
plugin = importlib.import_module(module_name)
except ModuleNotFoundError:
sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")
if not hasattr(plugin, 'add_arguments'):
sys.exit(f"{module_name} module is not a plugin. Exiting.")
print(f"Loaded {module_name} as the {plugin_type}.")
return plugin
def main():
import argparse
parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
allow_abbrev=False, conflict_handler='resolve')
parser.add_argument('--model-type', type=str, required=True,
choices=['GPT', 'BERT'],
help='Type of the model')
parser.add_argument('--loader', type=str, default='megatron',
help='Module name to load checkpoint, should be on python path')
parser.add_argument('--saver', type=str, default='megatron',
help='Module name to save checkpoint, should be on python path')
parser.add_argument('--load-dir', type=str, required=True,
help='Directory to load model checkpoint from')
parser.add_argument('--save-dir', type=str, required=True,
help='Directory to save model checkpoint to')
parser.add_argument('--max-queue-size', type=int, default=50,
help='Maximum number of tensors in the queue')
parser.add_argument('--no-checking', action='store_false',
help='Do not perform checking on the name and ordering of weights',
dest='checking')
known_args, _ = parser.parse_known_args()
loader = load_plugin('loader', known_args.loader)
saver = load_plugin('saver', known_args.saver)
loader.add_arguments(parser)
saver.add_arguments(parser)
args = parser.parse_args()
queue = mp.Queue(maxsize=args.max_queue_size)
print("Starting saver...")
saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
saver_proc.start()
print("Starting loader...")
loader.load_checkpoint(queue, args)
print("Waiting for saver to complete...")
saver_proc.join()
if __name__ == '__main__':
main()