Commit 52a5f2f2 authored by Deepak Narayanan

Intra-layer MP -> Tensor MP, Inter-layer MP -> Pipeline MP

parent 7abd3e90
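This commit is a pure rename of the parallelism terminology: what was previously called intra-layer model parallelism becomes tensor model parallelism, and inter-layer model parallelism becomes pipeline model parallelism; arguments, mpu setters, and parameter attributes are updated accordingly. As a point of reference, the two sizes combine with data parallelism to cover the full world size. The sketch below is illustrative only (the helper name is made up, not part of Megatron):

```python
# Illustrative sketch, not Megatron code: how the renamed sizes relate
# to the data-parallel size for a given number of ranks.
def data_parallel_size(world_size, tensor_model_parallel_size,
                       pipeline_model_parallel_size):
    model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
    assert world_size % model_parallel_size == 0, \
        'world size must be divisible by tensor-MP size x pipeline-MP size'
    return world_size // model_parallel_size


# Example: 16 ranks with 2-way tensor MP and 4-way pipeline MP
# leave 2-way data parallelism.
assert data_parallel_size(16, 2, 4) == 2
```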
@@ -32,7 +32,7 @@ from megatron.data.realm_dataset_utils import get_ict_batch
 def pretrain_ict_model_provider():
     args = get_args()
-    assert args.inter_layer_model_parallel_size == 1, 'inter_layer_model_parallel_size must be 1!'
+    assert args.pipeline_model_parallel_size == 1, 'pipeline_model_parallel_size must be 1!'
     return general_ict_model_provider(False, False)
@@ -89,7 +89,7 @@ def forward_step(data_iterator, model, input_tensor):
     # Forward model.
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
     local_batch_size = query_logits.shape[0]
-    global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that intra_layer_model_parallel_size == 1
+    global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that tensor_model_parallel_size == 1
     all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
     all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
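In the forward_step hunk above, the global batch size is the local batch size multiplied by the world size (valid only because tensor_model_parallel_size == 1 is asserted, so every rank is a data-parallel replica), and AllgatherFromDataParallelRegion collects the logits from all replicas so the retrieval loss sees the full global batch. A minimal sketch of such an autograd-aware all-gather, assuming the default process group is the data-parallel group; this is illustrative and not Megatron's actual AllgatherFromDataParallelRegion:

```python
# Illustrative sketch, not Megatron's implementation: all-gather activations
# across data-parallel ranks in forward, return only the local gradient
# slice in backward.
import torch
import torch.distributed as dist


class AllgatherWithGrad(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor):
        ctx.local_batch_size = tensor.shape[0]
        gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor.contiguous())
        # Concatenate along the batch dimension: global batch = world_size * local batch.
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank owns gradients only for the rows it contributed.
        start = dist.get_rank() * ctx.local_batch_size
        return grad_output[start:start + ctx.local_batch_size]
```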
@@ -188,18 +188,18 @@ def main():
     # Args
     args = _parse_args(extra_args_provider=get_mp_merge_args)
     model_type = args.model_type
-    orig_intra_layer_model_parallel_size = args.intra_layer_model_parallel_size
-    args.intra_layer_model_parallel_size = 1
+    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
+    args.tensor_model_parallel_size = 1
     tokenizer = rebuild_tokenizer(args)
     print('\n merging model parallel partitions ...')
-    print(' > number of partitions: {}'.format(orig_intra_layer_model_parallel_size))
+    print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
     print(' > checkpoint path: {}'.format(args.load))
     print(' > model parameters:')
     print(' number of tokens ................ {} '.format(
         tokenizer.vocab_size))
     print(' number of layers ................ {}'.format(args.num_layers))
-    print(' hidden sise ..................... {}'.format(args.hidden_size))
+    print(' hidden size ..................... {}'.format(args.hidden_size))
     print(' number of attention heads ....... {}'.format(
         args.num_attention_heads))
     print(' maximum position embeddings ..... {}'.format(
@@ -207,18 +207,18 @@ def main():
     # Full model.
     print('> building the full model ...')
-    mpu.initialize.set_intra_layer_model_parallel_world_size(1)
-    mpu.initialize.set_intra_layer_model_parallel_rank(0)
+    mpu.initialize.set_tensor_model_parallel_world_size(1)
+    mpu.initialize.set_tensor_model_parallel_rank(0)
     merged_model = get_model(model_type)
     # Build and load partitions.
     partitions = []
     iteration = 0
-    args.intra_layer_model_parallel_size = orig_intra_layer_model_parallel_size
+    args.tensor_model_parallel_size = orig_tensor_model_parallel_size
     tokenizer = rebuild_tokenizer(args)
-    mpu.initialize.set_intra_layer_model_parallel_world_size(args.intra_layer_model_parallel_size)
-    for rank in range(args.intra_layer_model_parallel_size):
-        mpu.initialize.set_intra_layer_model_parallel_rank(rank)
+    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+    for rank in range(args.tensor_model_parallel_size):
+        mpu.initialize.set_tensor_model_parallel_rank(rank)
         checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
         print('> loading {} ...'.format(checkpoint_name))
         model_ = get_model(model_type)
@@ -248,7 +248,7 @@ def main():
                 rank, partition_param.dtype, list(partition_param.size())))
         # For the non-parallel parameters, simply copy the rank 0 values.
-        if not hasattr(merged_param, 'intra_layer_model_parallel'):
+        if not hasattr(merged_param, 'tensor_model_parallel'):
             print('     none-parallel parameter, simple copy from rank 0')
             with torch.no_grad():
                 merged_param.data.copy_(partitions_param[0].data)
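The branch shown above only covers parameters without the (renamed) tensor_model_parallel attribute, which are identical across partitions and can simply be copied from rank 0. The branch that merges tensor-parallel parameters is elided from this hunk; as a hedged illustration of the general idea, partitioned weights are stitched back together along their split dimension (function and argument names below are hypothetical, not the merge script's code):

```python
# Illustrative sketch, not the merge script's actual code: non-parallel
# parameters are the same on every rank, tensor-parallel parameters are
# concatenated along their partition dimension.
import torch

def merge_parameter(partition_params, is_tensor_parallel, partition_dim=0):
    if not is_tensor_parallel:
        return partition_params[0].clone()
    return torch.cat(partition_params, dim=partition_dim)


# Example: a weight split across 2 tensor-parallel ranks along dim 0.
parts = [torch.ones(2, 4), 2 * torch.ones(2, 4)]
merged = merge_parameter(parts, is_tensor_parallel=True)
assert merged.shape == (4, 4)
```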
@@ -267,8 +267,8 @@ def main():
     # Save the model.
-    args.intra_layer_model_parallel_size = 1
-    mpu.initialize.set_intra_layer_model_parallel_rank(0)
+    args.tensor_model_parallel_size = 1
+    mpu.initialize.set_tensor_model_parallel_rank(0)
     sd = {}
     sd['model'] = merged_model.state_dict_for_save_checkpoint()
     sd['iteration'] = iteration
@@ -136,7 +136,7 @@ def get_args():
     # some default/dummy values for the tokenizer
     args.rank = 0
     args.make_vocab_size_divisible_by = 128
-    args.intra_layer_model_parallel_size = 1
+    args.tensor_model_parallel_size = 1
     return args