Commit 7db9c1e5 authored by Jared Casper

Merge branch 'ckpt_merge' into 'main'

Fixing merge_mp_partitions

See merge request ADLR/megatron-lm!213
parents 6afaffa5 7be2648a
......@@ -370,8 +370,7 @@ python tools/create_doc_index.py \
We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning.
<!--
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently only tensor model parallelism is supported (not pipeline model parallelism).
<pre>
TENSOR_MODEL_PARALLEL_SIZE=2
......@@ -390,9 +389,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--seq-length 512 \
--max-position-embeddings 512 \
--load $CHECKPOINT_PATH \
--save $CHECKPOINT_PATH/merged
</pre>
-->
Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.
## GPT Text Generation
......
......@@ -91,6 +91,20 @@ def parse_args(extra_args_provider=None, defaults={},
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
# Batch size.
assert args.micro_batch_size is not None
assert args.micro_batch_size > 0
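
This is the defaults-override mechanism that lets callers such as `tools/merge_mp_partitions.py` pass `args_defaults` (see the `set_global_variables` call further down in this merge request): a caller-supplied default is applied only when the corresponding argument is still `None`, i.e. was not given on the command line. A condensed, hypothetical sketch of the pattern (not the actual `parse_args`):

<pre>
import argparse

def parse_args_with_defaults(defaults, argv=None):
    """Hypothetical condensed version of the defaults-override logic above."""
    parser = argparse.ArgumentParser()
    # Leave the argparse default as None so we can tell whether the user
    # actually supplied a value on the command line.
    parser.add_argument('--micro-batch-size', type=int, default=None)
    args = parser.parse_args(argv)

    for key, value in defaults.items():
        if getattr(args, key) is not None:
            # The user supplied a value: keep it and warn that the
            # caller's default is being ignored.
            print(f'WARNING: overriding default {key}={value} '
                  f'with {key}={getattr(args, key)}')
        else:
            setattr(args, key, value)
    return args

# The caller-supplied default wins only when the flag is absent:
assert parse_args_with_defaults({'micro_batch_size': 1}, []).micro_batch_size == 1
assert parse_args_with_defaults({'micro_batch_size': 1},
                                ['--micro-batch-size', '4']).micro_batch_size == 4
</pre>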
......@@ -113,20 +127,6 @@ def parse_args(extra_args_provider=None, defaults={},
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
# Iteration-based training.
if args.train_iters:
# If we use iteration-based training, make sure the
......@@ -435,9 +435,9 @@ def _add_checkpointing_args(parser):
help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
group.add_argument('--no-load-optim', action='store_true',
group.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true',
group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.')
group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer '
......@@ -506,7 +506,7 @@ def _add_distributed_args(parser):
' and returns function to complete it instead.'
'Also turns on --use-cpu-initialization flag.'
'This is for external DDP manager.' )
group.add_argument('--use-cpu-initialization', action='store_true',
group.add_argument('--use-cpu-initialization', action='store_true', default=None,
help='If set, affine parallel weights initialization uses CPU' )
return parser
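
The `default=None` added to these `store_true` flags is what makes the override check above usable for booleans: argparse gives `store_true` flags an implicit default of `False`, which the `is not None` test would mistake for a user-supplied value. A small illustration (hypothetical flag names):

<pre>
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--plain-flag', action='store_true')
parser.add_argument('--nullable-flag', action='store_true', default=None)

args = parser.parse_args([])
# Without default=None the flag already looks "set", so an args_defaults
# entry for it would never be applied; with default=None it stays unset.
assert args.plain_flag is False
assert args.nullable_flag is None

# When the flag is actually passed, both behave the same.
args = parser.parse_args(['--nullable-flag'])
assert args.nullable_flag is True
</pre>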
......
......@@ -31,8 +31,9 @@ _CHECKPOINT_VERSION = None
def set_checkpoint_version(value):
global _CHECKPOINT_VERSION
assert _CHECKPOINT_VERSION is None, \
"checkpoint version already set"
if _CHECKPOINT_VERSION is not None:
assert _CHECKPOINT_VERSION == value, \
"checkpoint versions do not match"
_CHECKPOINT_VERSION = value
def get_checkpoint_version():
......@@ -111,11 +112,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if isinstance(model, torchDDP):
model = model.module
if torch.distributed.get_rank() == 0:
print('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True)
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
if mpu.get_data_parallel_rank() == 0:
if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
state_dict = {}
......@@ -146,17 +146,21 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
torch.save(state_dict, checkpoint_name)
# Wait so everyone is done (necessary)
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True)
if torch.distributed.is_initialized():
torch.distributed.barrier()
print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
# And update the latest iteration
if torch.distributed.get_rank() == 0:
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier()
if torch.distributed.is_initialized():
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
......@@ -197,9 +201,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
# Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if torch.distributed.get_rank() == 0:
print(' loading checkpoint from {} at iteration {}'.format(
args.load, iteration), flush=True)
print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
# Load the checkpoint.
try:
......@@ -284,10 +286,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
'exiting ...'.format(checkpoint_name))
sys.exit()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' successfully loaded checkpoint from {} at iteration {}'.format(
args.load, iteration), flush=True)
# Some utilities want to load a checkpoint without distributed being initialized
if torch.distributed.is_initialized():
torch.distributed.barrier()
print_rank_0(f' successfully loaded checkpoint from {args.load} '
f'at iteration {iteration}')
return iteration
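
The new `torch.distributed.is_initialized()` guards let single-process utilities such as `tools/merge_mp_partitions.py`, which never call `torch.distributed.init_process_group`, reuse `save_checkpoint` and `load_checkpoint`. A minimal sketch of the guard pattern (helper names are illustrative, not part of the codebase):

<pre>
import torch

def is_rank_0():
    """True on rank 0, or when torch.distributed was never set up."""
    return (not torch.distributed.is_initialized()
            or torch.distributed.get_rank() == 0)

def barrier_if_distributed():
    """Synchronize ranks only when running under torch.distributed."""
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

# Usage mirroring the checkpoint code: only one process writes the tracker
# file, and the barrier is a no-op in single-process tools.
# if is_rank_0():
#     with open(tracker_filename, 'w') as f:
#         f.write(str(iteration))
# barrier_if_distributed()
</pre>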
......
......@@ -78,9 +78,7 @@ class BertLMHead(MegatronModule):
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.tensor_model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = 1
mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
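
`mpu.set_tensor_model_parallel_attributes` replaces the three manual attribute assignments. Judging from the attributes read back later in this merge request (`tensor_model_parallel`, `partition_dim`, `partition_stride`), the helper presumably behaves roughly like the sketch below (an approximation, not the actual `mpu.layers` code):

<pre>
import torch

def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    # Tag a parameter with the metadata the merge/copy helpers rely on.
    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride

# e.g. the BertLMHead bias: partitioned along dim 0 with stride 1.
bias = torch.nn.Parameter(torch.zeros(8))
set_tensor_model_parallel_attributes(bias, True, 0, 1)
</pre>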
......
......@@ -60,6 +60,13 @@ class MegatronModule(torch.nn.Module):
if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layer, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
......@@ -73,16 +80,16 @@ class MegatronModule(torch.nn.Module):
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage():
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings
# weights to 0 here, then copy first stage's weights using
# all_reduce below.
self.word_embeddings = mpu.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = mpu.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter
# values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
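
The zero-fill plus the all-reduce mentioned in the comment is a standard way to tie the embedding and head weights across pipeline stages: the stages that zero their copy contribute nothing to the sum, so after the reduction every stage holds the first stage's values. A generic sketch of the idea (assumes a process group spanning the first and last stages; not the exact module code):

<pre>
import torch

def sync_shared_embedding(weight, is_source_stage, group=None):
    # Stages that do not own the "real" copy contribute zeros, so the
    # all-reduce sum leaves every stage holding the source stage's values.
    if not is_source_stage:
        weight.data.fill_(0)
    torch.distributed.all_reduce(weight.data, group=group)
</pre>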
......
......@@ -44,7 +44,8 @@ from .initialize import model_parallel_is_initialized
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import (set_defaults_if_not_set_tensor_model_parallel_attributes,
from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region
......
......@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_model_parallel_rank()
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
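
The `weight_list[rank::world_size]` slice implements strided partitioning on the CPU-initialization path: the master weight is split into `world_size * stride` equal chunks along the partition dimension, and each rank keeps every `world_size`-th chunk starting at its own index. A small worked example with made-up sizes:

<pre>
import torch

master = torch.arange(12)        # stand-in for the master weight, partition dim 0
world_size, stride = 2, 3        # 2 tensor-parallel ranks, stride 3 (e.g. fused QKV)
chunks = torch.split(master, master.numel() // (world_size * stride))

rank0 = torch.cat(chunks[0::world_size])  # chunks 0, 2, 4 -> tensor([0, 1, 4, 5, 8, 9])
rank1 = torch.cat(chunks[1::world_size])  # chunks 1, 3, 5 -> tensor([2, 3, 6, 7, 10, 11])
</pre>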
......@@ -260,9 +260,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=args.params_dtype))
self.bias.tensor_model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = stride
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
......
......@@ -23,11 +23,13 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
import torch
from megatron import mpu
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_version
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import set_global_variables, get_args
from megatron.global_vars import rebuild_tokenizer
from megatron.global_vars import _parse_args
def split_into_partitions(tensor, num_partitions, partition_dim, stride):
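
`split_into_partitions` (and the merge logic further down) has to honor this strided layout: fused parameters such as the QKV projection are interleaved in stride groups rather than stored contiguously per rank. A minimal sketch of how strided partitions can be reassembled, assuming each rank holds every `num_partitions`-th chunk along the partition dimension (the tool's actual `merge_partitions` may differ in detail):

<pre>
import torch

def merge_strided_partitions(merged, partitions, dim, stride):
    """Copy per-rank shards back into the full parameter, honoring stride."""
    num_partitions = len(partitions)
    chunk = merged.size(dim) // (num_partitions * stride)
    # torch.split returns views, so copying into them fills `merged` in place.
    merged_chunks = torch.split(merged, chunk, dim=dim)
    for rank, part in enumerate(partitions):
        for k, shard in enumerate(torch.split(part, chunk, dim=dim)):
            merged_chunks[k * num_partitions + rank].copy_(shard)

# Round-trip check against the strided split layout used at initialization.
full = torch.arange(12.0)
parts = [torch.cat(torch.split(full, 2)[r::2]) for r in range(2)]   # 2 ranks, stride 3
out = torch.empty_like(full)
merge_strided_partitions(out, parts, dim=0, stride=3)
assert torch.equal(out, full)
</pre>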
......@@ -185,8 +187,23 @@ def get_mp_merge_args(parser):
def main():
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking there are plenty of processes
os.environ["WORLD_SIZE"] = f'{2**31}'
# Args
args = _parse_args(extra_args_provider=get_mp_merge_args)
set_global_variables(extra_args_provider=get_mp_merge_args,
args_defaults = {'use_cpu_initialization': True,
'micro_batch_size': 1,
'no_load_optim': True,
'no_load_rng': True,
'save_interval': 1})
args = get_args()
if args.pipeline_model_parallel_size > 1:
print("Checkpoints with pipeline model parallelism are not currently supported.")
exit()
model_type = args.model_type
orig_tensor_model_parallel_size = args.tensor_model_parallel_size
args.tensor_model_parallel_size = 1
......@@ -209,6 +226,8 @@ def main():
print('> building the full model ...')
mpu.initialize.set_tensor_model_parallel_world_size(1)
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(1)
mpu.initialize.set_pipeline_model_parallel_rank(0)
merged_model = get_model(model_type)
# Build and load partitions.
......@@ -220,13 +239,16 @@ def main():
for rank in range(args.tensor_model_parallel_size):
mpu.initialize.set_tensor_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
print('> loading {} ...'.format(checkpoint_name))
model_ = get_model(model_type)
sd = torch.load(checkpoint_name, map_location='cpu')
model_.load_state_dict(sd['model'])
print(f'> loading {checkpoint_name} ...')
load_checkpoint(model_, None, None)
print(f'> checkpoint version {get_checkpoint_version()}')
if get_checkpoint_version() < 2.0:
# Need to deal with the qkv matrix order of old versions
print("Checkpoints less than version 2.0 are not currently supported.")
exit()
partitions.append(model_)
# Parameter generators so we can loop through them simultaneously.
merged_params_gen = merged_model.named_parameters()
partitions_params_gen = [partition.named_parameters()
......@@ -254,29 +276,26 @@ def main():
merged_param.data.copy_(partitions_param[0].data)
# For parallel parameters, merge the values
else:
print(' parallel parameter merge with stride {} along '
'dimension {}'.format(merged_param.stride,
merged_param.partition_dim))
dim = merged_param.partition_dim
stride = merged_param.partition_stride
print(f' parallel parameter merge with stride {stride} along '
f'dimension {dim}')
merge_partitions(merged_param,
partitions_param,
merged_param.partition_dim,
merged_param.stride)
dim,
stride)
except StopIteration:
break
# Save the model.
args.tensor_model_parallel_size = 1
args.pipeline_model_parallel_size = 1
# And now one last time so proper arguments are set in saved checkpoint
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_rank(0)
sd = {}
sd['model'] = merged_model.state_dict_for_save_checkpoint()
sd['iteration'] = iteration
merged_path = os.path.join(args.load, 'merged')
checkpoint_name = get_checkpoint_name(merged_path, iteration)
ensure_directory_exists(checkpoint_name)
print('> saving merged model to {}'.format(checkpoint_name))
torch.save(sd, checkpoint_name)
print('> saving merged model')
save_checkpoint(iteration, merged_model, None, None)
print('done :-)')
......