Commit 36c2674c authored by Mohammad Shoeybi

Merge branch 'ckpt_merge' into 'main'

Teach merge_mp_partitions how to write out a pipelined model.

See merge request ADLR/megatron-lm!218
parents de722164 98a5b9a0
...@@ -370,10 +370,11 @@ python tools/create_doc_index.py \
We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning.
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently, only tensor model parallelism is supported on the input and only pipeline model parallelism on the output. This example reads in a model with 2-way tensor model parallelism and writes out a model with 2-way pipeline model parallelism.
<pre>
TENSOR_MODEL_PARALLEL_SIZE=2
TARGET_PIPELINE_MODEL_PARALLEL_SIZE=2
VOCAB_FILE=bert-vocab.txt
CHECKPOINT_PATH=checkpoints/bert_345m
...@@ -381,6 +382,8 @@ CHECKPOINT_PATH=checkpoints/bert_345m
WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
        --model-type BERT \
        --tensor-model-parallel-size $TENSOR_MODEL_PARALLEL_SIZE \
        --pipeline-model-parallel-size 1 \
        --target-pipeline-model-parallel-size $TARGET_PIPELINE_MODEL_PARALLEL_SIZE \
        --tokenizer-type BertWordPieceLowerCase \
        --vocab-file $VOCAB_FILE \
        --num-layers 24 \
......
...@@ -98,11 +98,16 @@ class MegatronModule(torch.nn.Module):
        # Ensure that first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
                torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                             group=mpu.get_embedding_group())
        else:
            print("WARNING! Distributed processes aren't initialized, so "
                  "word embeddings in the last layer are not initialized. "
                  "If you are just manipulating a model this is fine, but "
                  "this needs to be handled manually. If you are training "
                  "something is definitely wrong.")


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val`
......
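
As an aside, here is a minimal, self-contained sketch (not part of this merge request; the gloo process group, port, and embedding sizes are made up for illustration) of why the all-reduce in the hunk above leaves the first and last pipeline stages with identical embedding weights: each stage initializes its own copy, and an in-place sum produces the same tensor on both ranks.
<pre>
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Each "stage" starts from a differently seeded copy of the embeddings.
    torch.manual_seed(rank)
    word_embeddings = torch.nn.Embedding(8, 4)

    # Summing the copies in place leaves every rank with identical values,
    # which is what the tied input/output embeddings need before training.
    dist.all_reduce(word_embeddings.weight.data, op=dist.ReduceOp.SUM)
    print(f'rank {rank}: {word_embeddings.weight.data[0]}')

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(run, args=(2,), nprocs=2)
</pre>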
...@@ -16,6 +16,7 @@
"""Merge model parallel partitions.""" """Merge model parallel partitions."""
import os import os
import re
import sys import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir))) os.path.pardir)))
...@@ -181,6 +182,8 @@ def get_mp_merge_args(parser):
    group.add_argument('--model-type', type=str, required=True,
                       choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
                       help='Type of the model.')
    group.add_argument('--target-pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism in output model.')

    return parser
...@@ -284,14 +287,55 @@ def main():
        except StopIteration:
            break
    partitions = []
    args.tensor_model_parallel_size = 1
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
    # And now one last time so proper arguments are set in saved checkpoint
    assert args.num_layers % args.pipeline_model_parallel_size == 0, \
        'num_layers must be divisible by target pipeline model parallel size'
    layers_per_part = args.num_layers // args.pipeline_model_parallel_size

    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)

    # regex to parse out layer number from param name
    layer_re = re.compile('layers\.([0-9]+)')

    if args.pipeline_model_parallel_size > 1:
        merged_params = {}
        for name, merged_param in merged_model.named_parameters():
            merged_params[name] = merged_param

        for rank in range(args.pipeline_model_parallel_size):
            mpu.initialize.set_pipeline_model_parallel_rank(rank)
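            # Build the model shard for this pipeline stage; with the parallel
            # state set above it holds layers_per_part of the merged layers.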
            model = get_model(model_type)

            def update_layer_num(m):
                # TODO! This assumes no interleaved pipeline execution
                layer = int(m.group(1))
                layer += rank * layers_per_part
                return f'layers.{layer}'

            for dst_name, partition_param in model.named_parameters():
                if dst_name == "word_embeddings.weight":
                    # See comment in MegatronModule.initialize_word_embeddings()
                    src_name = "language_model.embedding.word_embeddings.weight"
                else:
                    # Translate destination layer number (0-N for each partition)
                    # to source layer number (single-model layer number)
                    src_name = re.sub(layer_re, update_layer_num, dst_name)
                print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
                partition_param.data.copy_(merged_params[src_name].data)

            partitions.append(model)

    else:
        partitions = [merged_model]
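    # Write out one checkpoint per pipeline stage rank.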
    for rank, model in enumerate(partitions):
        mpu.initialize.set_pipeline_model_parallel_rank(rank)
        print(f"> saving rank {rank}'s model")
        save_checkpoint(iteration, model, None, None)

    print('done :-)')
......
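
To make the renumbering above concrete, here is a small standalone sketch (the parameter name is illustrative, not read from a real checkpoint) of what update_layer_num does when the merged model is split across pipeline stages: each stage's local layers.N parameter is copied from global layer N + rank * layers_per_part of the merged model.
<pre>
import re

num_layers = 24          # layers in the merged model
pipeline_size = 2        # target pipeline-model-parallel size
layers_per_part = num_layers // pipeline_size  # 12 layers per stage

layer_re = re.compile(r'layers\.([0-9]+)')

def src_name_for(dst_name, rank):
    """Map a per-stage parameter name to its name in the merged model."""
    def update_layer_num(m):
        return f'layers.{int(m.group(1)) + rank * layers_per_part}'
    return re.sub(layer_re, update_layer_num, dst_name)

name = 'language_model.encoder.layers.3.mlp.dense_h_to_4h.weight'
print(src_name_for(name, rank=0))  # ...layers.3...  (stage 0 keeps layers 0-11)
print(src_name_for(name, rank=1))  # ...layers.15... (stage 1 reads layers 12-23)
</pre>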