Commit 7db9c1e5 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'ckpt_merge' into 'main'

Fixing merge_mp_partitions

See merge request ADLR/megatron-lm!213
parents 6afaffa5 7be2648a
...@@ -370,8 +370,7 @@ python tools/create_doc_index.py \ ...@@ -370,8 +370,7 @@ python tools/create_doc_index.py \
We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning. We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning.
<!-- Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently only tensor model parallelism is supported (not pipeline model parallelism).
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
<pre> <pre>
TENSOR_MODEL_PARALLEL_SIZE=2 TENSOR_MODEL_PARALLEL_SIZE=2
...@@ -390,9 +389,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ ...@@ -390,9 +389,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--seq-length 512 \ --seq-length 512 \
--max-position-embeddings 512 \ --max-position-embeddings 512 \
--load $CHECKPOINT_PATH --load $CHECKPOINT_PATH
--save $CHECKPOINT_PATH/merged
</pre> </pre>
-->
Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts. Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.
## GPT Text Generation ## GPT Text Generation
......
...@@ -91,6 +91,20 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -91,6 +91,20 @@ def parse_args(extra_args_provider=None, defaults={},
'longer valid, use --tensor-model-parallel-size instead' 'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size del args.model_parallel_size
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
# Batch size. # Batch size.
assert args.micro_batch_size is not None assert args.micro_batch_size is not None
assert args.micro_batch_size > 0 assert args.micro_batch_size > 0
...@@ -113,20 +127,6 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -113,20 +127,6 @@ def parse_args(extra_args_provider=None, defaults={},
args.consumed_train_samples = 0 args.consumed_train_samples = 0
args.consumed_valid_samples = 0 args.consumed_valid_samples = 0
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
# Iteration-based training. # Iteration-based training.
if args.train_iters: if args.train_iters:
# If we use iteration-based training, make sure the # If we use iteration-based training, make sure the
...@@ -435,9 +435,9 @@ def _add_checkpointing_args(parser): ...@@ -435,9 +435,9 @@ def _add_checkpointing_args(parser):
help='Do not save current rng state.') help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None, group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.') help='Directory containing a model checkpoint.')
group.add_argument('--no-load-optim', action='store_true', group.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.') help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.') help='Do not load rng state when loading checkpoint.')
group.add_argument('--finetune', action='store_true', group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer ' help='Load model for finetuning. Do not load optimizer '
...@@ -506,7 +506,7 @@ def _add_distributed_args(parser): ...@@ -506,7 +506,7 @@ def _add_distributed_args(parser):
' and returns function to complete it instead.' ' and returns function to complete it instead.'
'Also turns on --use-cpu-initialization flag.' 'Also turns on --use-cpu-initialization flag.'
'This is for external DDP manager.' ) 'This is for external DDP manager.' )
group.add_argument('--use-cpu-initialization', action='store_true', group.add_argument('--use-cpu-initialization', action='store_true', default=None,
help='If set, affine parallel weights initialization uses CPU' ) help='If set, affine parallel weights initialization uses CPU' )
return parser return parser
......
...@@ -31,8 +31,9 @@ _CHECKPOINT_VERSION = None ...@@ -31,8 +31,9 @@ _CHECKPOINT_VERSION = None
def set_checkpoint_version(value): def set_checkpoint_version(value):
global _CHECKPOINT_VERSION global _CHECKPOINT_VERSION
assert _CHECKPOINT_VERSION is None, \ if _CHECKPOINT_VERSION is not None:
"checkpoint version already set" assert _CHECKPOINT_VERSION == value, \
"checkpoint versions do not match"
_CHECKPOINT_VERSION = value _CHECKPOINT_VERSION = value
def get_checkpoint_version(): def get_checkpoint_version():
...@@ -111,11 +112,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -111,11 +112,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if isinstance(model, torchDDP): if isinstance(model, torchDDP):
model = model.module model = model.module
if torch.distributed.get_rank() == 0: print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
print('saving checkpoint at iteration {:7d} to {}'.format( iteration, args.save))
iteration, args.save), flush=True)
if mpu.get_data_parallel_rank() == 0: if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model. # Arguments, iteration, and model.
state_dict = {} state_dict = {}
...@@ -146,17 +146,21 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -146,17 +146,21 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
torch.save(state_dict, checkpoint_name) torch.save(state_dict, checkpoint_name)
# Wait so everyone is done (necessary) # Wait so everyone is done (necessary)
torch.distributed.barrier() if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0: torch.distributed.barrier()
print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True) print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
# And update the latest iteration # And update the latest iteration
if torch.distributed.get_rank() == 0: if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save) tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f: with open(tracker_filename, 'w') as f:
f.write(str(iteration)) f.write(str(iteration))
# Wait so everyone is done (not necessary) # Wait so everyone is done (not necessary)
torch.distributed.barrier() if torch.distributed.is_initialized():
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
...@@ -197,9 +201,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): ...@@ -197,9 +201,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
# Checkpoint. # Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release) checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if torch.distributed.get_rank() == 0: print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
print(' loading checkpoint from {} at iteration {}'.format(
args.load, iteration), flush=True)
# Load the checkpoint. # Load the checkpoint.
try: try:
...@@ -284,10 +286,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): ...@@ -284,10 +286,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
'exiting ...'.format(checkpoint_name)) 'exiting ...'.format(checkpoint_name))
sys.exit() sys.exit()
torch.distributed.barrier() # Some utilities want to load a checkpoint without distributed being initialized
if torch.distributed.get_rank() == 0: if torch.distributed.is_initialized():
print(' successfully loaded checkpoint from {} at iteration {}'.format( torch.distributed.barrier()
args.load, iteration), flush=True)
print_rank_0(f' successfully loaded checkpoint from {args.load} '
f'at iteration {iteration}')
return iteration return iteration
......
...@@ -78,9 +78,7 @@ class BertLMHead(MegatronModule): ...@@ -78,9 +78,7 @@ class BertLMHead(MegatronModule):
args = get_args() args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.tensor_model_parallel = True mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.bias.partition_dim = 0
self.bias.stride = 1
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
......
...@@ -60,6 +60,13 @@ class MegatronModule(torch.nn.Module): ...@@ -60,6 +60,13 @@ class MegatronModule(torch.nn.Module):
if not self.share_word_embeddings: if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but ' raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false') 'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layer, and the # Parameters are shared between the word embeddings layer, and the
# heads at the end of the model. In a pipelined setup with more than # heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different # one stage, the initial embedding layer and the head are on different
...@@ -73,16 +80,16 @@ class MegatronModule(torch.nn.Module): ...@@ -73,16 +80,16 @@ class MegatronModule(torch.nn.Module):
# the two word_embeddings layers to ensure that every applied weight # the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages. # update is the same on both stages.
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage(): assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head' self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings # set word_embeddings weights to 0 here, then copy first
# weights to 0 here, then copy first stage's weights using # stage's weights using all_reduce below.
# all_reduce below. self.word_embeddings = mpu.VocabParallelEmbedding(
self.word_embeddings = mpu.VocabParallelEmbedding( args.padded_vocab_size, args.hidden_size,
args.padded_vocab_size, args.hidden_size, init_method=init_method_normal(args.init_method_std))
init_method=init_method_normal(args.init_method_std)) self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True
self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter # Ensure that first and last stages have the same initial parameter
# values. # values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage(): if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
......
...@@ -44,7 +44,8 @@ from .initialize import model_parallel_is_initialized ...@@ -44,7 +44,8 @@ from .initialize import model_parallel_is_initialized
from .layers import ColumnParallelLinear from .layers import ColumnParallelLinear
from .layers import RowParallelLinear from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding from .layers import VocabParallelEmbedding
from .layers import (set_defaults_if_not_set_tensor_model_parallel_attributes, from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes) copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
......
...@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_per_stride_size = divide(per_partition_size, stride) per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim) dim=partition_dim)
rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size] my_weight_list = weight_list[rank::world_size]
...@@ -260,9 +260,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -260,9 +260,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition, self.output_size_per_partition,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=args.params_dtype))
self.bias.tensor_model_parallel = True set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
......
...@@ -23,11 +23,13 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ...@@ -23,11 +23,13 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
import torch import torch
from megatron import mpu from megatron import mpu
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.checkpointing import ensure_directory_exists from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_version
from megatron.checkpointing import get_checkpoint_tracker_filename from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import set_global_variables, get_args
from megatron.global_vars import rebuild_tokenizer from megatron.global_vars import rebuild_tokenizer
from megatron.global_vars import _parse_args
def split_into_partitions(tensor, num_partitions, partition_dim, stride): def split_into_partitions(tensor, num_partitions, partition_dim, stride):
...@@ -185,8 +187,23 @@ def get_mp_merge_args(parser): ...@@ -185,8 +187,23 @@ def get_mp_merge_args(parser):
def main(): def main():
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
os.environ["WORLD_SIZE"] = f'{2**31}'
# Args # Args
args = _parse_args(extra_args_provider=get_mp_merge_args) set_global_variables(extra_args_provider=get_mp_merge_args,
args_defaults = {'use_cpu_initialization': True,
'micro_batch_size': 1,
'no_load_optim': True,
'no_load_rng': True,
'save_interval': 1})
args = get_args()
if args.pipeline_model_parallel_size > 1:
print("Checkpoints with pipeline model parallelism are not currently supported.")
exit()
model_type = args.model_type model_type = args.model_type
orig_tensor_model_parallel_size = args.tensor_model_parallel_size orig_tensor_model_parallel_size = args.tensor_model_parallel_size
args.tensor_model_parallel_size = 1 args.tensor_model_parallel_size = 1
...@@ -209,6 +226,8 @@ def main(): ...@@ -209,6 +226,8 @@ def main():
print('> building the full model ...') print('> building the full model ...')
mpu.initialize.set_tensor_model_parallel_world_size(1) mpu.initialize.set_tensor_model_parallel_world_size(1)
mpu.initialize.set_tensor_model_parallel_rank(0) mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(1)
mpu.initialize.set_pipeline_model_parallel_rank(0)
merged_model = get_model(model_type) merged_model = get_model(model_type)
# Build and load partitions. # Build and load partitions.
...@@ -220,13 +239,16 @@ def main(): ...@@ -220,13 +239,16 @@ def main():
for rank in range(args.tensor_model_parallel_size): for rank in range(args.tensor_model_parallel_size):
mpu.initialize.set_tensor_model_parallel_rank(rank) mpu.initialize.set_tensor_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load) checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
print('> loading {} ...'.format(checkpoint_name))
model_ = get_model(model_type) model_ = get_model(model_type)
sd = torch.load(checkpoint_name, map_location='cpu') print(f'> loading {checkpoint_name} ...')
model_.load_state_dict(sd['model']) load_checkpoint(model_, None, None)
print(f'> checkpoint version {get_checkpoint_version()}')
if get_checkpoint_version() < 2.0:
# Need to deal with the qkv matrix order of old versions
print("Checkpoints less than version 2.0 are not currently supported.")
exit()
partitions.append(model_) partitions.append(model_)
# Parameter generators so we can loop through them semiltaneouly. # Parameter generators so we can loop through them semiltaneouly.
merged_params_gen = merged_model.named_parameters() merged_params_gen = merged_model.named_parameters()
partitions_params_gen = [partition.named_parameters() partitions_params_gen = [partition.named_parameters()
...@@ -254,29 +276,26 @@ def main(): ...@@ -254,29 +276,26 @@ def main():
merged_param.data.copy_(partitions_param[0].data) merged_param.data.copy_(partitions_param[0].data)
# For parallel parameters, merge the values # For parallel parameters, merge the values
else: else:
print(' parallel parameter merge with stride {} along ' dim = merged_param.partition_dim
'dimention {}'.format(merged_param.stride, stride = merged_param.partition_stride
merged_param.partition_dim)) print(f' parallel parameter merge with stride {stride} along '
f'dimention {dim}')
merge_partitions(merged_param, merge_partitions(merged_param,
partitions_param, partitions_param,
merged_param.partition_dim, dim,
merged_param.stride) stride)
except StopIteration: except StopIteration:
break break
# Save the model. # Save the model.
args.tensor_model_parallel_size = 1 args.tensor_model_parallel_size = 1
args.pipeline_model_parallel_size = 1
# And now one last time so proper arguments are set in saved checkpoint
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_rank(0) mpu.initialize.set_tensor_model_parallel_rank(0)
sd = {} print('> saving merged model')
sd['model'] = merged_model.state_dict_for_save_checkpoint() save_checkpoint(iteration, merged_model, None, None)
sd['iteration'] = iteration
merged_path = os.path.join(args.load, 'merged')
checkpoint_name = get_checkpoint_name(merged_path, iteration)
ensure_directory_exists(checkpoint_name)
print('> saving merged model to {}'.format(checkpoint_name))
torch.save(sd, checkpoint_name)
print('done :-)') print('done :-)')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment