Commit 52a5f2f2 authored by Deepak Narayanan

Intra-layer MP -> Tensor MP, Inter-layer MP -> Pipeline MP

parent 7abd3e90
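
This commit is a pure rename: the old intra-layer/inter-layer model-parallel names become the tensor/pipeline names used in the hunks below (args.tensor_model_parallel_size, args.pipeline_model_parallel_size, and the matching mpu.initialize setters). As a rough, standalone illustration of the renamed attribute names only (the sizes 8 and 2 are made up, not taken from this commit):

from argparse import Namespace

# Hypothetical settings, for illustration only: the attribute names match the
# rename in this diff, the values are invented.
old_args = Namespace(intra_layer_model_parallel_size=8,
                     inter_layer_model_parallel_size=2)
new_args = Namespace(tensor_model_parallel_size=8,    # splits each layer across GPUs
                     pipeline_model_parallel_size=2)  # splits the layer stack into stages

print(new_args.tensor_model_parallel_size, new_args.pipeline_model_parallel_size)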
@@ -32,7 +32,7 @@ from megatron.data.realm_dataset_utils import get_ict_batch
 def pretrain_ict_model_provider():
     args = get_args()
-    assert args.inter_layer_model_parallel_size == 1, 'inter_layer_model_parallel_size must be 1!'
+    assert args.pipeline_model_parallel_size == 1, 'pipeline_model_parallel_size must be 1!'
     return general_ict_model_provider(False, False)
@@ -89,7 +89,7 @@ def forward_step(data_iterator, model, input_tensor):
     # Forward model.
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
     local_batch_size = query_logits.shape[0]
-    global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that intra_layer_model_parallel_size == 1
+    global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that tensor_model_parallel_size == 1
     all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
     all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
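
The comment on the changed line only holds because tensor_model_parallel_size is asserted to be 1 above: every rank in the default process group is then a data-parallel replica with its own micro-batch, so gathering across all ranks yields world_size * local_batch_size samples. A minimal arithmetic sketch with hypothetical sizes:

# Hypothetical sizes, for illustration only.
world_size = 8        # what dist.get_world_size() would return
local_batch_size = 4  # query_logits.shape[0] on each rank

# With tensor_model_parallel_size == 1, all 8 ranks hold full model replicas,
# so the all-gathered logits cover 8 * 4 = 32 examples.
global_batch_size = world_size * local_batch_size
assert global_batch_size == 32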
@@ -188,18 +188,18 @@ def main():
     # Args
     args = _parse_args(extra_args_provider=get_mp_merge_args)
     model_type = args.model_type
-    orig_intra_layer_model_parallel_size = args.intra_layer_model_parallel_size
-    args.intra_layer_model_parallel_size = 1
+    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
+    args.tensor_model_parallel_size = 1
     tokenizer = rebuild_tokenizer(args)
     print('\n merging model parallel partitions ...')
-    print(' > number of partitions: {}'.format(orig_intra_layer_model_parallel_size))
+    print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
     print(' > checkpoint path: {}'.format(args.load))
     print(' > model parameters:')
     print(' number of tokens ................ {} '.format(
         tokenizer.vocab_size))
     print(' number of layers ................ {}'.format(args.num_layers))
-    print(' hidden sise ..................... {}'.format(args.hidden_size))
+    print(' hidden size ..................... {}'.format(args.hidden_size))
     print(' number of attention heads ....... {}'.format(
         args.num_attention_heads))
     print(' maximum position embeddings ..... {}'.format(
@@ -207,18 +207,18 @@ def main():
     # Full model.
     print('> building the full model ...')
-    mpu.initialize.set_intra_layer_model_parallel_world_size(1)
-    mpu.initialize.set_intra_layer_model_parallel_rank(0)
+    mpu.initialize.set_tensor_model_parallel_world_size(1)
+    mpu.initialize.set_tensor_model_parallel_rank(0)
     merged_model = get_model(model_type)
     # Build and load partitions.
     partitions = []
     iteration = 0
-    args.intra_layer_model_parallel_size = orig_intra_layer_model_parallel_size
+    args.tensor_model_parallel_size = orig_tensor_model_parallel_size
     tokenizer = rebuild_tokenizer(args)
-    mpu.initialize.set_intra_layer_model_parallel_world_size(args.intra_layer_model_parallel_size)
-    for rank in range(args.intra_layer_model_parallel_size):
-        mpu.initialize.set_intra_layer_model_parallel_rank(rank)
+    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+    for rank in range(args.tensor_model_parallel_size):
+        mpu.initialize.set_tensor_model_parallel_rank(rank)
         checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
         print('> loading {} ...'.format(checkpoint_name))
         model_ = get_model(model_type)
@@ -248,7 +248,7 @@ def main():
                 rank, partition_param.dtype, list(partition_param.size())))
         # For the non-parallel parameters, simply copy the rank 0 values.
-        if not hasattr(merged_param, 'intra_layer_model_parallel'):
+        if not hasattr(merged_param, 'tensor_model_parallel'):
             print(' none-parallel parameter, simple copy from rank 0')
             with torch.no_grad():
                 merged_param.data.copy_(partitions_param[0].data)
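
The branch above handles parameters that are not split across tensor-parallel ranks: they lack the tensor_model_parallel attribute and are identical in every partition, so rank 0's copy becomes the merged value. A minimal sketch of that copy, using hypothetical tensors in place of the real model parameters:

import torch

# Hypothetical stand-ins: a destination parameter and two identical per-rank copies.
merged_param = torch.nn.Parameter(torch.empty(4))
partitions_param = [torch.full((4,), 0.5), torch.full((4,), 0.5)]

# Same copy as in the diff: take rank 0's values without tracking gradients.
with torch.no_grad():
    merged_param.data.copy_(partitions_param[0].data)

assert torch.equal(merged_param.data, partitions_param[0])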
@@ -267,8 +267,8 @@ def main():
     # Save the model.
-    args.intra_layer_model_parallel_size = 1
-    mpu.initialize.set_intra_layer_model_parallel_rank(0)
+    args.tensor_model_parallel_size = 1
+    mpu.initialize.set_tensor_model_parallel_rank(0)
     sd = {}
     sd['model'] = merged_model.state_dict_for_save_checkpoint()
     sd['iteration'] = iteration
@@ -136,7 +136,7 @@ def get_args():
     # some default/dummy values for the tokenizer
     args.rank = 0
     args.make_vocab_size_divisible_by = 128
-    args.intra_layer_model_parallel_size = 1
+    args.tensor_model_parallel_size = 1
     return args