Commit 7abd3e90 authored by Deepak Narayanan

Pipeline parallelism implementation with periodic full-pipeline syncs

Also includes the following changes for the inter-layer model-parallel implementation:
- Refactoring of model implementations
- Training loop changes to support inter-layer communication using `ring_exchange`
- New groups for inter-layer communication
- Checkpoint changes
- Command line arguments
parent 28cd66e1
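The training-loop change is easiest to see as a schedule: each pipeline stage first runs the forward pass for every microbatch in the minibatch, receiving activations from the previous stage and sending its output downstream, then runs the backward passes in the same order with gradients flowing the opposite way, followed by the usual all-reduce and optimizer step (the periodic full-pipeline sync). The sketch below is a single-process simulation of that schedule, not code from this commit: `ring_exchange` is assumed to come from a patched PyTorch build, so plain Python hand-offs stand in for `communicate()`, and the two-stage split, stand-in loss, and helper names are illustrative only.

```python
import torch

def pipeline_train_step(stages, microbatches):
    """Simulate one train_step(): forward passes for all microbatches first,
    then backward passes in the same order (as in the diff below)."""
    cached = [[] for _ in stages]   # per stage: [input, output] per microbatch
    for mb in microbatches:
        activation = mb
        for i, stage in enumerate(stages):
            # A non-first stage receives a detached activation and needs its
            # .grad later, mirroring input_tensor.retain_grad() in the diff.
            inp = activation.detach().requires_grad_(True) if i > 0 else activation
            out = stage(inp)
            cached[i].append([inp, out])
            activation = out
        # Only the last stage turns its output into a loss.
        cached[-1][-1][1] = activation.float().mean()
    for mb_idx in range(len(microbatches)):
        grad = None                 # the last stage starts from the scalar loss
        for i in reversed(range(len(stages))):
            inp, out = cached[i][mb_idx]
            torch.autograd.backward(out, grad_tensors=grad)
            # The gradient w.r.t. this stage's input is what would be sent
            # back to the previous stage via communicate().
            grad = inp.grad if i > 0 else None

if __name__ == "__main__":
    stages = [torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)]   # two toy stages
    microbatches = [torch.randn(4, 8) for _ in range(2)]
    pipeline_train_step(stages, microbatches)
    print(stages[0].weight.grad.shape)   # grads accumulated over 2 microbatches
```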
......@@ -27,6 +27,7 @@ from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module
......@@ -123,8 +124,10 @@ def get_model(model_provider_func):
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
print(' > number of parameters on (intra-layer, inter-layer) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_intra_layer_model_parallel_rank(),
mpu.get_inter_layer_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation.
......@@ -135,6 +138,9 @@ def get_model(model_provider_func):
model = FP16_Module(model)
# Wrap model for distributed training.
if args.use_pipelining:
assert args.DDP_impl == 'local'
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
model = torchDDP(model, device_ids=[i], output_device=i,
......@@ -160,8 +166,8 @@ def get_optimizer(model):
# Add model parallel attribute if it is not set.
for param_group in param_groups:
for param in param_group['params']:
if not hasattr(param, 'model_parallel'):
param.model_parallel = False
if not hasattr(param, 'intra_layer_model_parallel'):
param.intra_layer_model_parallel = False
# Use Adam.
optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
......@@ -231,27 +237,144 @@ def setup_model_and_optimizer(model_provider_func):
return model, optimizer, lr_scheduler
def backward_step(optimizer, model, loss):
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
"""Communicate tensors between stages using torch.distributed.ring_exchange(.) API."""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.batch_size, args.seq_length, args.hidden_size)
if recv_forward:
tensor_recv_prev = torch.empty(tensor_shape,
requires_grad=True,
dtype=args.params_dtype).cuda()
if recv_backward:
tensor_recv_next = torch.empty(tensor_shape,
requires_grad=True,
dtype=args.params_dtype).cuda()
# Send tensors in both the forward and backward directions as appropriate.
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_inter_layer_model_parallel_group())
return tensor_recv_prev, tensor_recv_next
def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
"""Backward step."""
args = get_args()
timers = get_timers()
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Backward pass.
timers('backward-backward').start()
if args.fp16:
optimizer.backward(output_tensor, update_master_grads=False,
output_tensor_grad=output_tensor_grad)
else:
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
timers('backward-backward').stop()
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
return input_tensor_grad
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
# Set grad to zero.
if args.fp16:
optimizer.zero_grad(set_grads_to_None=True)
optimizer.backward(loss, update_master_grads=False)
else:
optimizer.zero_grad()
loss.backward()
timers('backward-backward').stop()
# Compute number of microbatches in a minibatch.
num_microbatches_to_pipeline = args.inter_layer_model_parallel_size \
if args.use_pipelining else 1
input_tensors = []
output_tensors = []
losses_reduced = []
# Run forward pass for all microbatches in minibatch.
for i in range(num_microbatches_to_pipeline):
if not mpu.is_inter_layer_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
# Forward model for one step.
timers('forward').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward').stop()
if mpu.is_inter_layer_last_stage():
loss, loss_reduced = output_tensor
output_tensor = loss
losses_reduced.append(loss_reduced)
else:
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Run backward pass for all microbatches in minibatch.
for i in range(num_microbatches_to_pipeline):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
if mpu.is_inter_layer_last_stage():
output_grad_tensor = None
else:
_, output_grad_tensor = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=False,
recv_backward=True)
# Backward pass for one step.
# TODO: This timer is a bit redundant now with backward-backward.
timers('backward').start()
input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_grad_tensor)
timers('backward').stop()
if not mpu.is_inter_layer_first_stage():
communicate(
tensor_send_next=None,
tensor_send_prev=input_grad_tensor,
recv_forward=False,
recv_backward=False)
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-allreduce').start()
timers('allreduce').start()
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
timers('backward-allreduce').stop()
timers('allreduce').stop()
# Update master gradients.
timers('backward-master-grad').start()
......@@ -259,32 +382,33 @@ def backward_step(optimizer, model, loss):
optimizer.update_master_grads()
timers('backward-master-grad').stop()
# All-reduce across first and last stages.
if (mpu.is_inter_layer_first_stage() or mpu.is_inter_layer_last_stage()) and \
args.inter_layer_model_parallel_size > 1:
unwrapped_model = model
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
unwrapped_model = unwrapped_model.module
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
torch.distributed.all_reduce(word_embeddings_weight.grad,
group=mpu.get_embedding_group())
# Clipping gradients helps prevent exploding gradients.
timers('backward-clip-grad').start()
if args.clip_grad > 0:
if args.clip_grad > 0.:
if not args.fp16:
mpu.clip_grad_norm(model.parameters(), args.clip_grad)
named_parameters = model.named_parameters()
parameters = []
parameter_names = []
for parameter_name, parameter in model.named_parameters():
parameters.append(parameter)
parameter_names.append(parameter_name)
mpu.clip_grad_norm(parameters, args.clip_grad,
parameter_names=parameter_names)
else:
optimizer.clip_master_grads(args.clip_grad)
timers('backward-clip-grad').stop()
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
# Forward model for one step.
timers('forward').start()
loss, loss_reduced = forward_step_func(data_iterator, model)
timers('forward').stop()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
backward_step(optimizer, model, loss)
timers('backward').stop()
# Update parameters.
timers('optimizer').start()
optimizer.step()
......@@ -297,7 +421,15 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
return loss_reduced, skipped_iter
if mpu.is_inter_layer_last_stage():
# Average loss across microbatches.
loss_reduced = {}
for key in losses_reduced[0]:
losses_reduced_for_key = [x[key] for x in losses_reduced]
loss_reduced[key] = sum(losses_reduced_for_key) / \
len(losses_reduced_for_key)
return loss_reduced, skipped_iter
return {}, skipped_iter
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
......@@ -382,7 +514,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
total_loss_dict[got_nan_key])
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[got_nan_key] = 0
print_rank_0(log_string)
print_rank_last(log_string)
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
......@@ -471,12 +603,32 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
if not mpu.is_inter_layer_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
# Forward evaluation.
_, loss_dict = forward_step_func(data_iterator, model)
# Reduce across processes.
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
loss_dict[key]
output_tensor = forward_step_func(data_iterator, model, input_tensor)
if mpu.is_inter_layer_last_stage():
_, loss_dict = output_tensor
# Reduce across processes.
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
loss_dict[key]
else:
communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
# Move model back to the train mode.
model.train()
......@@ -505,9 +657,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
writer.add_scalar('{} ppl'.format(key), ppl, iteration)
length = len(string) + 1
print_rank_0('-' * length)
print_rank_0(string)
print_rank_0('-' * length)
print_rank_last('-' * length)
print_rank_last(string)
print_rank_last('-' * length)
def build_train_valid_test_data_iterators(
......@@ -519,7 +671,7 @@ def build_train_valid_test_data_iterators(
print_rank_0('> building train, validation, and test datasets ...')
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
if mpu.get_intra_layer_model_parallel_rank() == 0:
# Rank, size, and global batch size.
data_parallel_size = mpu.get_data_parallel_world_size()
global_batch_size = args.batch_size * data_parallel_size
......@@ -557,8 +709,8 @@ def build_train_valid_test_data_iterators(
# Broadcast num tokens.
torch.distributed.broadcast(flags,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
mpu.get_intra_layer_model_parallel_src_rank(),
group=mpu.get_intra_layer_model_parallel_group())
args.do_train = flags[0].item()
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
......
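For reference, `torch.distributed.ring_exchange` used by `communicate()` above is not an API in stock PyTorch; this commit assumes a build that provides it. A hedged sketch of the same bidirectional exchange written against the standard `isend`/`irecv` primitives might look as follows; the function name, the fixed tensor shape, and the two-rank gloo demo are illustrative assumptions, not the repository's code.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def exchange(tensor_send_next, tensor_send_prev, recv_forward, recv_backward,
             shape, dtype, group=None):
    """Post all sends/receives to the neighbouring pipeline stages, then wait,
    mimicking the communicate() helper in the diff above."""
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)
    prev_rank, next_rank = (rank - 1) % world_size, (rank + 1) % world_size

    tensor_recv_prev = torch.empty(shape, dtype=dtype) if recv_forward else None
    tensor_recv_next = torch.empty(shape, dtype=dtype) if recv_backward else None

    reqs = []
    if tensor_send_prev is not None:                 # gradients flow backward
        reqs.append(dist.isend(tensor_send_prev, prev_rank, group=group))
    if tensor_recv_prev is not None:                 # activations flow forward
        reqs.append(dist.irecv(tensor_recv_prev, prev_rank, group=group))
    if tensor_send_next is not None:
        reqs.append(dist.isend(tensor_send_next, next_rank, group=group))
    if tensor_recv_next is not None:
        reqs.append(dist.irecv(tensor_recv_next, next_rank, group=group))
    for req in reqs:
        req.wait()
    return tensor_recv_prev, tensor_recv_next

def _worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    shape, dtype = (2, 3), torch.float32
    if rank == 0:        # "first stage": send an activation to the next stage
        exchange(torch.full(shape, 7.0), None, recv_forward=False,
                 recv_backward=False, shape=shape, dtype=dtype)
    else:                # "last stage": receive the activation from the left
        recv, _ = exchange(None, None, recv_forward=True,
                           recv_backward=False, shape=shape, dtype=dtype)
        print("received", recv[0, 0].item())
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```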
......@@ -28,14 +28,16 @@ from megatron.data.samplers import DistributedBatchSampler
from megatron.fp16 import FP16_Optimizer
def reduce_losses(losses):
def average_losses_across_data_parallel_group(losses):
"""Reduce a tensor of losses across all GPUs."""
reduced_losses = torch.cat(
averaged_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(reduced_losses)
reduced_losses = reduced_losses / torch.distributed.get_world_size()
torch.distributed.all_reduce(averaged_losses,
group=mpu.get_data_parallel_group())
averaged_losses = averaged_losses / \
torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
return reduced_losses
return averaged_losses
def report_memory(name):
......@@ -56,7 +58,7 @@ def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, model-parallel,min, max, norm\n'
string = 'iteration, rank, index, intra-layer-model-parallel, min, max, norm\n'
optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer
......@@ -67,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
max_ = param.data.max()
norm = param.data.norm()
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.model_parallel))
iteration, rank, index, int(param.intra_layer_model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
......
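The rename from `reduce_losses` to `average_losses_across_data_parallel_group` reflects a real change of scope: with intra-layer and inter-layer model parallelism, only the data-parallel replicas hold identical copies of a given loss (and with pipelining only the last stage computes one at all), so both the all-reduce and the divisor must be restricted to the data-parallel group rather than the whole job. A small worked example of the group arithmetic, with all sizes hypothetical:

```python
# Hypothetical job layout, for illustration only.
world_size = 16                       # total ranks in the job
intra_layer_model_parallel_size = 2   # tensor-parallel ranks per stage
inter_layer_model_parallel_size = 4   # pipeline stages

# Ranks are partitioned into data-parallel groups of this size; every rank in
# such a group holds the same loss value for its microbatch.
data_parallel_size = world_size // (intra_layer_model_parallel_size *
                                    inter_layer_model_parallel_size)
assert data_parallel_size == 2

# Averaging therefore divides the all-reduced sum by 2 (the group size),
# not by 16 (the world size) as the old reduce_losses() did.
```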
......@@ -23,9 +23,9 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from megatron.training import pretrain
from megatron.utils import reduce_losses
from megatron.utils import average_losses_across_data_parallel_group
def model_provider():
......@@ -33,10 +33,25 @@ def model_provider():
print_rank_0('building BERT model ...')
model = BertModel(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
args = get_args()
if args.inter_layer_model_parallel_size > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_inter_layer_first_stage():
model = BertModelFirstStage(
num_tokentypes=2)
elif mpu.is_inter_layer_last_stage():
model = BertModelLastStage(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
else:
model = BertModelIntermediateStage(
num_tokentypes=2)
else:
model = BertModel(
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
return model
......@@ -66,7 +81,7 @@ def get_batch(data_iterator):
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def forward_step(data_iterator, model):
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
timers = get_timers()
......@@ -77,23 +92,40 @@ def forward_step(data_iterator, model):
= get_batch(data_iterator)
timers('batch generator').stop()
# Forward model. lm_labels
lm_loss_, sop_logits = model(tokens, padding_mask,
tokentype_ids=types,
lm_labels=lm_labels)
# Forward pass through the model.
if mpu.is_inter_layer_first_stage():
assert input_tensor is None
if mpu.is_inter_layer_last_stage():
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
else:
output_tensor = model(tokens, padding_mask, tokentype_ids=types)
elif mpu.is_inter_layer_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, padding_mask)
if mpu.is_inter_layer_last_stage():
lm_loss_, sop_logits = output_tensor
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
loss = lm_loss + sop_loss
loss = lm_loss + sop_loss
reduced_losses = reduce_losses([lm_loss, sop_loss])
averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}
return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
return output_tensor
def train_valid_test_datasets_provider(train_val_test_num_samples):
......
......@@ -23,16 +23,28 @@ from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.gpt2_dataset import build_train_valid_test_datasets
from megatron.model import GPT2Model
from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import reduce_losses
from megatron.utils import average_losses_across_data_parallel_group
def model_provider():
"""Build the model."""
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_tokentypes=0, parallel_output=True)
args = get_args()
if args.inter_layer_model_parallel_size > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_inter_layer_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
elif mpu.is_inter_layer_last_stage():
model = GPT2ModelLastStage(
num_tokentypes=0, parallel_output=True)
else:
model = GPT2ModelIntermediateStage(
num_tokentypes=0)
else:
model = GPT2Model(num_tokentypes=0, parallel_output=True)
return model
......@@ -69,7 +81,7 @@ def get_batch(data_iterator):
return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model):
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
timers = get_timers()
......@@ -79,15 +91,32 @@ def forward_step(data_iterator, model):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch generator').stop()
# Forward model.
losses = model(tokens, position_ids, attention_mask, labels=labels)
loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
reduced_loss = reduce_losses([loss])
# Forward pass through the model.
if mpu.is_inter_layer_first_stage():
assert input_tensor is None
if mpu.is_inter_layer_last_stage():
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
else:
output_tensor = model(tokens, position_ids, attention_mask)
elif mpu.is_inter_layer_last_stage():
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask, labels=labels)
else:
assert input_tensor is not None
output_tensor = model(input_tensor, attention_mask)
if mpu.is_inter_layer_last_stage():
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': reduced_loss[0]}
return loss, {'lm loss': averaged_loss[0]}
return output_tensor
def train_valid_test_datasets_provider(train_val_test_num_samples):
......
......@@ -25,12 +25,14 @@ from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.training import pretrain
from megatron.utils import reduce_losses
from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch
def pretrain_ict_model_provider():
args = get_args()
assert args.inter_layer_model_parallel_size == 1, 'inter_layer_model_parallel_size must be 1!'
return general_ict_model_provider(False, False)
......@@ -72,7 +74,7 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
return output
def forward_step(data_iterator, model):
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
timers = get_timers()
......@@ -87,7 +89,7 @@ def forward_step(data_iterator, model):
# Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
local_batch_size = query_logits.shape[0]
global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that model_parallel_size == 1
global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that intra_layer_model_parallel_size == 1
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
......@@ -102,11 +104,12 @@ def forward_step(data_iterator, model):
topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
reduced_losses = reduce_losses([retrieval_loss, *topk_accs])
retrieval_loss = retrieval_loss.float()
averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, reduced_losses[1:])}
stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)
topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
return retrieval_loss, stats_dict
......
......@@ -28,7 +28,7 @@ from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import reduce_losses
from megatron.utils import average_losses_across_data_parallel_group
def process_batch(batch):
......@@ -66,9 +66,9 @@ def _cross_entropy_forward_step(batch, model):
loss = loss_func(logits.contiguous().float(), labels)
# Reduce loss for logging.
reduced_loss = reduce_losses([loss])
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': reduced_loss[0]}
return loss, {'lm loss': averaged_loss[0]}
def build_data_loader(dataset, batch_size, num_workers, drop_last):
......
......@@ -188,12 +188,12 @@ def main():
# Args
args = _parse_args(extra_args_provider=get_mp_merge_args)
model_type = args.model_type
orig_model_parallel_size = args.model_parallel_size
args.model_parallel_size = 1
orig_intra_layer_model_parallel_size = args.intra_layer_model_parallel_size
args.intra_layer_model_parallel_size = 1
tokenizer = rebuild_tokenizer(args)
print('\n merging model parallel partitions ...')
print(' > number of partitions: {}'.format(orig_model_parallel_size))
print(' > number of partitions: {}'.format(orig_intra_layer_model_parallel_size))
print(' > checkpoint path: {}'.format(args.load))
print(' > model parameters:')
print(' number of tokens ................ {} '.format(
......@@ -207,18 +207,18 @@ def main():
# Full model.
print('> building the full model ...')
mpu.initialize.set_model_parallel_world_size(1)
mpu.initialize.set_model_parallel_rank(0)
mpu.initialize.set_intra_layer_model_parallel_world_size(1)
mpu.initialize.set_intra_layer_model_parallel_rank(0)
merged_model = get_model(model_type)
# Build and load partitions.
partitions = []
iteration = 0
args.model_parallel_size = orig_model_parallel_size
args.intra_layer_model_parallel_size = orig_intra_layer_model_parallel_size
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
for rank in range(args.model_parallel_size):
mpu.initialize.set_model_parallel_rank(rank)
mpu.initialize.set_intra_layer_model_parallel_world_size(args.intra_layer_model_parallel_size)
for rank in range(args.intra_layer_model_parallel_size):
mpu.initialize.set_intra_layer_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
print('> loading {} ...'.format(checkpoint_name))
model_ = get_model(model_type)
......@@ -248,7 +248,7 @@ def main():
rank, partition_param.dtype, list(partition_param.size())))
# For the non-parallel parameters, simply copy the rank 0 values.
if not hasattr(merged_param, 'model_parallel'):
if not hasattr(merged_param, 'intra_layer_model_parallel'):
print(' non-parallel parameter, simple copy from rank 0')
with torch.no_grad():
merged_param.data.copy_(partitions_param[0].data)
......@@ -267,8 +267,8 @@ def main():
# Save the model.
args.model_parallel_size = 1
mpu.initialize.set_model_parallel_rank(0)
args.intra_layer_model_parallel_size = 1
mpu.initialize.set_intra_layer_model_parallel_rank(0)
sd = {}
sd['model'] = merged_model.state_dict_for_save_checkpoint()
sd['iteration'] = iteration
......
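The merge tool above decides how to treat each weight by checking for the `intra_layer_model_parallel` attribute that the parallel layers set on their parameters: partitioned weights must be merged across ranks, while everything else is replicated and can simply be copied from rank 0. A tiny hedged helper capturing that check (the function name is illustrative, not part of the repo):

```python
import torch

def is_intra_layer_parallel(param: torch.nn.Parameter) -> bool:
    """True if this parameter is partitioned across intra-layer (tensor)
    model-parallel ranks; replicated parameters either lack the attribute
    or have it set to False (see get_optimizer in the training.py diff)."""
    return getattr(param, 'intra_layer_model_parallel', False)
```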
......@@ -136,7 +136,7 @@ def get_args():
# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.model_parallel_size = 1
args.intra_layer_model_parallel_size = 1
return args
......