Commit ca8dd4ac authored by Jared Casper

Merge branch 'mp_merger' into 'master'

Model parallel merger

See merge request ADLR/megatron-lm!28
parents 5df85022 57c2060f
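
This merge request adds a standalone script that folds model-parallel checkpoint partitions back into a single checkpoint: it impersonates each model-parallel rank in turn (using the new mpu rank/world-size overrides below), loads that rank's partition, concatenates every model-parallel parameter along its recorded partition dimension (honoring the stride for interleaved weights), and saves the merged model under a merged/ subdirectory of the checkpoint path. A typical invocation might look like python merge_mp_partitions.py --model-parallel-size 2 --vocab-size 30522 --load checkpoints/bert (the script name and exact flags are assumptions; see the script at the end of this diff).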
......@@ -47,7 +47,7 @@ def add_model_config_args(parser):
help='dropout probability for hidden state transformer')
group.add_argument('--max-position-embeddings', type=int, default=512,
help='maximum number of position embeddings to use')
-    group.add_argument('--vocab-size', type=int, default=30522,
+    group.add_argument('--vocab-size', type=int, default=None,
help='vocab size to use for non-character-level '
'tokenization. This value will only be used when '
'creating a tokenizer')
......
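(The default is removed, presumably so the merger cannot silently build a model with the wrong vocabulary: main below asserts that --vocab-size was given explicitly.)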
......@@ -83,6 +83,8 @@ class BertLMHead(MegatronModule):
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
+        self.bias.partition_dim = 0
+        self.bias.stride = 1
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
......
......@@ -372,6 +372,7 @@ class ParallelTransformerLayer(MegatronModule):
def __init__(self, hyperparameters, attention_mask_func, layer_number):
super(ParallelTransformerLayer, self).__init__()
+        self.layer_number = layer_number
self.apply_residual_connection_post_layernorm \
= hyperparameters['apply_residual_connection_post_layernorm']
......
......@@ -26,6 +26,10 @@ _MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
+# These values enable us to change the mpu sizes on the fly.
+_MPU_WORLD_SIZE = None
+_MPU_RANK = None
def initialize_model_parallel(model_parallel_size_):
"""
......@@ -99,13 +103,31 @@ def get_data_parallel_group():
return _DATA_PARALLEL_GROUP
+def set_model_parallel_world_size(world_size):
+    """Set the model parallel size"""
+    global _MPU_WORLD_SIZE
+    _MPU_WORLD_SIZE = world_size
def get_model_parallel_world_size():
"""Return world size for the model parallel group."""
+    global _MPU_WORLD_SIZE
+    if _MPU_WORLD_SIZE is not None:
+        return _MPU_WORLD_SIZE
return torch.distributed.get_world_size(group=get_model_parallel_group())
+def set_model_parallel_rank(rank):
+    """Set model parallel rank."""
+    global _MPU_RANK
+    _MPU_RANK = rank
def get_model_parallel_rank():
"""Return my rank for the model parallel group."""
+    global _MPU_RANK
+    if _MPU_RANK is not None:
+        return _MPU_RANK
return torch.distributed.get_rank(group=get_model_parallel_group())
......
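These setters let a single, non-distributed process impersonate any model-parallel rank: once the cached values are set, the getters return them without ever touching torch.distributed. A minimal sketch of the pattern the merger uses (plain single-process Python, no process group initialized):

from megatron import mpu

# Pretend to be each rank of a 2-way model-parallel job in turn.
mpu.initialize.set_model_parallel_world_size(2)
for rank in range(2):
    mpu.initialize.set_model_parallel_rank(rank)
    # Both getters short-circuit to the cached values here.
    assert mpu.initialize.get_model_parallel_world_size() == 2
    assert mpu.initialize.get_model_parallel_rank() == rank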
......@@ -46,6 +46,11 @@ def _initialize_affine_weight(weight, output_size, input_size,
Build the master weight on all processes and scatter
the relevant chunk."""
+    weight.model_parallel = True
+    weight.partition_dim = partition_dim
+    weight.stride = stride
# If we only use 1 process for model parallelism, bypass scatter.
world_size = get_model_parallel_world_size()
if world_size == 1:
......@@ -108,7 +113,6 @@ class VocabParallelEmbedding(torch.nn.Module):
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
self.embedding_dim))
-        self.weight.model_parallel = True
# And initialize.
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim,
......@@ -165,7 +169,6 @@ class ParallelEmbedding(torch.nn.Module):
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings,
self.embedding_dim_per_partition))
-        self.weight.model_parallel = True
# And initialize.
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim,
......@@ -220,10 +223,11 @@ class ColumnParallelLinear(torch.nn.Module):
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.output_size_per_partition,
self.input_size))
-        self.weight.model_parallel = True
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
self.bias.model_parallel = True
+            self.bias.partition_dim = 0
+            self.bias.stride = stride
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
......@@ -294,7 +298,6 @@ class RowParallelLinear(torch.nn.Module):
# we allocate the transpose.
self.weight = Parameter(torch.Tensor(self.output_size,
self.input_size_per_partition))
-        self.weight.model_parallel = True
if bias:
self.bias = Parameter(torch.Tensor(self.output_size))
# Always initialize bias to zero.
......
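Parallel parameters are now tagged in one place, _initialize_affine_weight, rather than per class; only the biases, which never pass through that helper, are tagged by hand. The merger depends on exactly this three-attribute contract, sketched here with assumed shapes:

import torch

w = torch.nn.Parameter(torch.empty(512, 1024))  # shapes are illustrative only
w.model_parallel = True   # split across model-parallel ranks
w.partition_dim = 0       # dimension along which it was split
w.stride = 1              # 1 = contiguous split; >1 = interleaved, e.g. fused QKV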
import os
import torch
from arguments import get_args
from megatron import mpu
from megatron.utils import ensure_directory_exists
from megatron.utils import get_checkpoint_name
from megatron.utils import get_checkpoint_tracker_filename
from megatron.utils import vocab_size_with_padding
def split_into_partitions(tensor, num_partitions, partition_dim, stride):
per_partition_size = mpu.utils.divide(tensor.size(partition_dim),
num_partitions)
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
partitions_list = torch.split(tensor,
per_partition_per_stride_size,
dim=partition_dim)
partitions = []
for i in range(num_partitions):
partition = torch.cat(partitions_list[i::num_partitions],
dim=partition_dim)
partitions.append(partition)
return partitions
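# Worked sketch (values assumed): for a tensor with 12 rows, partition_dim=0,
# num_partitions=2 and stride=3, per_partition_size is 6 and
# per_partition_per_stride_size is 2, so torch.split yields six 2-row chunks;
# partition 0 takes chunks 0, 2, 4 and partition 1 takes chunks 1, 3, 5, i.e.
# one sub-chunk per stride group, matching how strided (fused QKV) weights are
# laid out across ranks.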
def merge_partitions(merged, partitions, partition_dim, stride):
# Number and size of each partition.
num_partitions = len(partitions)
per_partition_size = None
for partition in partitions:
if per_partition_size is None:
per_partition_size = partition.size(partition_dim)
else:
assert per_partition_size == partition.size(partition_dim)
def concat_partitions(partitions_):
with torch.no_grad():
if (per_partition_size * num_partitions) == merged.size(
partition_dim):
torch.cat(partitions_, dim=partition_dim, out=merged)
else:
print(' ***WARNING*** sizes do not match. Will cut '
'the merged partitions by {} along dimension {} '
'to reduce the size from {} to {} ...'.format(
(per_partition_size * num_partitions) - \
merged.size(partition_dim), partition_dim,
per_partition_size * num_partitions,
merged.size(partition_dim)))
merged_ = torch.cat(partitions_, dim=partition_dim)
merged_split = torch.split(merged_, merged.size(partition_dim),
dim=partition_dim)
merged_ = merged_split[0]
assert merged_.size(partition_dim) == merged.size(partition_dim)
merged.data.copy_(merged_.data)
    # If stride is 1, then do simple concatenation.
if stride == 1:
concat_partitions(partitions)
return
    # For non-unity strides, first split based on stride and then group.
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
# Chunk and build a list.
chunks = None
for i, partition in enumerate(partitions):
chunk = torch.split(partition,
per_partition_per_stride_size,
dim=partition_dim)
if chunks is None:
            chunks = [None] * (num_partitions * len(chunk))
chunks[i::num_partitions] = chunk
    # Concatenate.
concat_partitions(chunks)
return
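# Inverse sketch: chunk j of partition i is placed at slot i + j*num_partitions
# before the final concatenation, undoing the interleaving above, so
# merge_partitions(merged, split_into_partitions(t, n, d, s), d, s) recovers t
# exactly; test_split_merge below checks this round trip.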
def get_model(model_type, args):
if model_type == 'BERT':
from pretrain_albert import model_provider
args.tokentype_size = 2
elif model_type == 'GPT':
from pretrain_gpt2 import model_provider
else:
raise Exception('unrecognized model type: {}'.format(model_type))
orig_vocab_size = args.vocab_size
args.vocab_size = vocab_size_with_padding(args.vocab_size, args)
model = model_provider(args)
model = model.half()
args.vocab_size = orig_vocab_size
return model
def get_parallel_checkpoint_name(path):
tracker_filename = get_checkpoint_tracker_filename(path)
iteration = 0
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = int(metastring)
assert iteration > 0
checkpoint_name = get_checkpoint_name(path, iteration)
return checkpoint_name, iteration
def test_split_merge():
print('testing split and merge ...')
    # [QKV.ROW-COL]: values encode stride group, row, and column, e.g. 2.13 is
    # stride group 2 (think Q/K/V of a fused weight), row 1, column 3.
tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
[1.21, 1.22, 1.23, 1.24, 1.25],
[1.31, 1.32, 1.33, 1.34, 1.35],
[1.41, 1.42, 1.43, 1.44, 1.45],
[2.11, 2.12, 2.13, 2.14, 2.15],
[2.21, 2.22, 2.23, 2.24, 2.25],
[2.31, 2.32, 2.33, 2.34, 2.35],
[2.41, 2.42, 2.43, 2.44, 2.45],
[3.11, 3.12, 3.13, 3.14, 3.15],
[3.21, 3.22, 3.23, 3.24, 3.25],
[3.31, 3.32, 3.33, 3.34, 3.35],
[3.41, 3.42, 3.43, 3.44, 3.45]])
num_partitions = 2
partition_dim = 0
stride = 3
partitions = split_into_partitions(tensor, num_partitions,
partition_dim, stride)
merged = torch.zeros_like(tensor)
merge_partitions(merged, partitions, partition_dim, stride)
max_error = (merged - tensor).abs().max()
print(' > max error (should be zero): {}'.format(max_error))
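    # With the 12x5 example above the round trip is exact, so the printed max
    # error is 0.0.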
def main(model_type):
# Args
args = get_args()
print('\n merging model parallel partitions ...')
assert args.vocab_size is not None
print(' > number of partitions: {}'.format(args.model_parallel_size))
print(' > checkpoint path: {}'.format(args.load))
print(' > model parameters:')
print(' number of tokens ................ {} '.format(args.vocab_size))
print(' number of layers ................ {}'.format(args.num_layers))
    print(' hidden size ..................... {}'.format(args.hidden_size))
print(' number of attention heads ....... {}'.format(
args.num_attention_heads))
print(' maximum position embeddings ..... {}'.format(
args.max_position_embeddings))
# Full model.
print('> building the full model ...')
mpu.initialize.set_model_parallel_world_size(1)
mpu.initialize.set_model_parallel_rank(0)
merged_model = get_model(model_type, args)
# Build and load partitions.
partitions = []
iteration = 0
mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
for rank in range(args.model_parallel_size):
mpu.initialize.set_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
print('> loading {} ...'.format(checkpoint_name))
model_ = get_model(model_type, args)
sd = torch.load(checkpoint_name, map_location='cpu')
model_.load_state_dict(sd['model'])
partitions.append(model_)
    # Parameter generators so we can loop through them simultaneously.
merged_params_gen = merged_model.named_parameters()
partitions_params_gen = [partition.named_parameters()
for partition in partitions]
while True:
try:
# Get the params and check names.
name, merged_param = next(merged_params_gen)
print(' > working on {} ...'.format(name))
print(' merged type: {}, size: {}'.format(
merged_param.dtype, list(merged_param.size())))
partitions_param = []
for rank, partition_params_gen in enumerate(partitions_params_gen):
partition_name, partition_param = next(partition_params_gen)
assert partition_name == name
partitions_param.append(partition_param)
print(' partition {} type: {}, size: {}'.format(
rank, partition_param.dtype, list(partition_param.size())))
# For the non-parallel parameters, simply copy the rank 0 values.
if not hasattr(merged_param, 'model_parallel'):
                print(' non-parallel parameter, simple copy from rank 0')
with torch.no_grad():
merged_param.data.copy_(partitions_param[0].data)
# For parallel parameters, merge the values
else:
                print(' parallel parameter merge with stride {} along '
                      'dimension {}'.format(merged_param.stride,
                                            merged_param.partition_dim))
merge_partitions(merged_param,
partitions_param,
merged_param.partition_dim,
merged_param.stride)
except StopIteration:
break
# Save the model.
mpu.initialize.set_model_parallel_rank(0)
sd = {}
sd['model'] = merged_model.state_dict_for_save_checkpoint()
sd['iteration'] = iteration
merged_path = os.path.join(args.load, 'merged')
checkpoint_name = get_checkpoint_name(merged_path, iteration)
ensure_directory_exists(checkpoint_name)
print('> saving merged model to {}'.format(checkpoint_name))
torch.save(sd, checkpoint_name)
print('done :-)')
if __name__ == '__main__':
main('BERT')
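
The entry point hardcodes BERT. Merging a GPT-2 checkpoint would mean swapping the final call, matching the two model types get_model handles:

if __name__ == '__main__':
    main('GPT')  # variant for GPT-2 checkpoints; the script as committed uses 'BERT'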