Commit 7abd3e90 authored by Deepak Narayanan

Pipeline parallelism implementation with periodic full-pipeline syncs

Also includes the following changes for the inter-layer model-parallel implementation:
- Refactoring of model implementations
- Training loop changes to support inter-layer communication using `ring_exchange`
- New groups for inter-layer communication (see the illustrative sketch below)
- Checkpoint changes
- Command line arguments
parent 28cd66e1
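
The new process groups for inter-layer (pipeline) communication mentioned above are not shown in the truncated diff below. The following is a minimal, hypothetical sketch of how such rank groups could be laid out; the function name, argument names, and the assumed rank ordering (consecutive ranks share an intra-layer group) are illustrative assumptions, not the actual `megatron.mpu` implementation.

```python
def build_parallel_rank_groups(world_size, intra_layer_size, inter_layer_size):
    """Return (intra_layer, inter_layer, data_parallel) lists of rank groups.

    Assumes world_size == intra_layer_size * inter_layer_size * data_parallel_size
    and that consecutive ranks share an intra-layer group.
    """
    assert world_size % (intra_layer_size * inter_layer_size) == 0
    data_parallel_size = world_size // (intra_layer_size * inter_layer_size)

    # Consecutive ranks form an intra-layer (tensor) model-parallel group.
    intra_layer_groups = [list(range(start, start + intra_layer_size))
                          for start in range(0, world_size, intra_layer_size)]

    # Ranks holding the same model shard on successive pipeline stages form an
    # inter-layer group; stages are spaced world_size // inter_layer_size apart.
    stage_stride = world_size // inter_layer_size
    inter_layer_groups = [list(range(offset, world_size, stage_stride))
                          for offset in range(stage_stride)]

    # Ranks with the same intra-layer rank within the same pipeline stage form
    # a data-parallel group.
    data_parallel_groups = []
    for stage_start in range(0, world_size, stage_stride):
        for intra_rank in range(intra_layer_size):
            data_parallel_groups.append(
                list(range(stage_start + intra_rank,
                           stage_start + stage_stride,
                           intra_layer_size)))

    assert all(len(g) == data_parallel_size for g in data_parallel_groups)
    # In the real code, each rank list would be passed to
    # torch.distributed.new_group(ranks) and cached inside megatron.mpu.
    return intra_layer_groups, inter_layer_groups, data_parallel_groups


# Example: 16 GPUs, intra-layer size 2, inter-layer (pipeline) size 4
# -> pipelines such as [0, 4, 8, 12], data-parallel groups such as [0, 2].
```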
@@ -27,6 +27,7 @@ from megatron import get_timers
 from megatron import get_tensorboard_writer
 from megatron import mpu
 from megatron import print_rank_0
+from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.fp16 import FP16_Module
@@ -123,8 +124,10 @@ def get_model(model_provider_func):
     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
-        print(' > number of parameters on model parallel rank {}: {}'.format(
-            mpu.get_model_parallel_rank(),
+        print(' > number of parameters on (intra-layer, inter-layer) '
+              'model parallel rank ({}, {}): {}'.format(
+            mpu.get_intra_layer_model_parallel_rank(),
+            mpu.get_inter_layer_model_parallel_rank(),
             sum([p.nelement() for p in model.parameters()])), flush=True)

     # GPU allocation.
@@ -135,6 +138,9 @@ def get_model(model_provider_func):
         model = FP16_Module(model)

     # Wrap model for distributed training."""
+    if args.use_pipelining:
+        assert args.DDP_impl == 'local'
+
     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
         model = torchDDP(model, device_ids=[i], output_device=i,
@@ -160,8 +166,8 @@ def get_optimizer(model):
     # Add model parallel attribute if it is not set.
     for param_group in param_groups:
         for param in param_group['params']:
-            if not hasattr(param, 'model_parallel'):
-                param.model_parallel = False
+            if not hasattr(param, 'intra_layer_model_parallel'):
+                param.intra_layer_model_parallel = False

     # Use Adam.
     optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
@@ -231,27 +237,144 @@ def setup_model_and_optimizer(model_provider_func):
     return model, optimizer, lr_scheduler


-def backward_step(optimizer, model, loss):
+def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
+    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API."""
+    args = get_args()
+
+    # Create placeholder tensors for receive in forward and backward directions
+    # if needed.
+    tensor_recv_prev = None
+    tensor_recv_next = None
+    tensor_shape = (args.batch_size, args.seq_length, args.hidden_size)
+    if recv_forward:
+        tensor_recv_prev = torch.empty(tensor_shape,
+                                       requires_grad=True,
+                                       dtype=args.params_dtype).cuda()
+    if recv_backward:
+        tensor_recv_next = torch.empty(tensor_shape,
+                                       requires_grad=True,
+                                       dtype=args.params_dtype).cuda()
+
+    # Send tensors in both the forward and backward directions as appropriate.
+    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
+                                    tensor_recv_prev=tensor_recv_prev,
+                                    tensor_send_next=tensor_send_next,
+                                    tensor_recv_next=tensor_recv_next,
+                                    group=mpu.get_inter_layer_model_parallel_group())
+
+    return tensor_recv_prev, tensor_recv_next
+
+
+def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
     """Backward step."""
     args = get_args()
     timers = get_timers()

+    # Retain the grad on the input_tensor.
+    if input_tensor is not None:
+        input_tensor.retain_grad()
+
     # Backward pass.
     timers('backward-backward').start()
+    if args.fp16:
+        optimizer.backward(output_tensor, update_master_grads=False,
+                           output_tensor_grad=output_tensor_grad)
+    else:
+        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+    timers('backward-backward').stop()
+
+    # Collect the grad of the input_tensor.
+    input_tensor_grad = None
+    if input_tensor is not None:
+        input_tensor_grad = input_tensor.grad
+
+    return input_tensor_grad
+
+
+def train_step(forward_step_func, data_iterator,
+               model, optimizer, lr_scheduler):
+    """Single training step."""
+    args = get_args()
+    timers = get_timers()
+
+    # Set grad to zero.
     if args.fp16:
         optimizer.zero_grad(set_grads_to_None=True)
-        optimizer.backward(loss, update_master_grads=False)
     else:
         optimizer.zero_grad()
-        loss.backward()
-    timers('backward-backward').stop()
+
+    # Compute number of microbatches in a minibatch.
+    num_microbatches_to_pipeline = args.inter_layer_model_parallel_size \
+        if args.use_pipelining else 1
+
+    input_tensors = []
+    output_tensors = []
+    losses_reduced = []
+
+    # Run forward pass for all microbatches in minibatch.
+    for i in range(num_microbatches_to_pipeline):
+        if not mpu.is_inter_layer_first_stage():
+            input_tensor, _ = communicate(
+                tensor_send_next=None,
+                tensor_send_prev=None,
+                recv_forward=True,
+                recv_backward=False)
+        else:
+            input_tensor = None
+
+        # Forward model for one step.
+        timers('forward').start()
+        output_tensor = forward_step_func(data_iterator, model, input_tensor)
+        timers('forward').stop()
+
+        if mpu.is_inter_layer_last_stage():
+            loss, loss_reduced = output_tensor
+            output_tensor = loss
+            losses_reduced.append(loss_reduced)
+        else:
+            communicate(
+                tensor_send_next=output_tensor,
+                tensor_send_prev=None,
+                recv_forward=False,
+                recv_backward=False)
+
+        input_tensors.append(input_tensor)
+        output_tensors.append(output_tensor)
+
+    # Run backward pass for all microbatches in minibatch.
+    for i in range(num_microbatches_to_pipeline):
+        input_tensor = input_tensors.pop(0)
+        output_tensor = output_tensors.pop(0)
+
+        if mpu.is_inter_layer_last_stage():
+            output_grad_tensor = None
+        else:
+            _, output_grad_tensor = communicate(
+                tensor_send_next=None,
+                tensor_send_prev=None,
+                recv_forward=False,
+                recv_backward=True)
+
+        # Backward pass for one step.
+        # TODO: This timer is a bit redundant now with backward-backward.
+        timers('backward').start()
+        input_grad_tensor = \
+            backward_step(optimizer, model, input_tensor, output_tensor, output_grad_tensor)
+        timers('backward').stop()
+
+        if not mpu.is_inter_layer_first_stage():
+            communicate(
+                tensor_send_next=None,
+                tensor_send_prev=input_grad_tensor,
+                recv_forward=False,
+                recv_backward=False)

     # All-reduce if needed.
     if args.DDP_impl == 'local':
-        timers('backward-allreduce').start()
+        timers('allreduce').start()
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
-        timers('backward-allreduce').stop()
+        timers('allreduce').stop()

     # Update master gradients.
     timers('backward-master-grad').start()
@@ -259,32 +382,33 @@ def backward_step(optimizer, model, loss):
     optimizer.update_master_grads()
     timers('backward-master-grad').stop()

+    # All-reduce across first and last stages.
+    if (mpu.is_inter_layer_first_stage() or mpu.is_inter_layer_last_stage()) and \
+            args.inter_layer_model_parallel_size > 1:
+        unwrapped_model = model
+        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
+            unwrapped_model = unwrapped_model.module
+        word_embeddings_weight = unwrapped_model.word_embeddings_weight()
+        torch.distributed.all_reduce(word_embeddings_weight.grad,
+                                     group=mpu.get_embedding_group())
+
     # Clipping gradients helps prevent the exploding gradient.
     timers('backward-clip-grad').start()
-    if args.clip_grad > 0:
+    if args.clip_grad > 0.:
         if not args.fp16:
-            mpu.clip_grad_norm(model.parameters(), args.clip_grad)
+            named_parameters = model.named_parameters()
+            parameters = []
+            parameter_names = []
+            for parameter_name, parameter in model.named_parameters():
+                parameters.append(parameter)
+                parameter_names.append(parameter_name)
+            mpu.clip_grad_norm(parameters, args.clip_grad,
+                               parameter_names=parameter_names)
         else:
             optimizer.clip_master_grads(args.clip_grad)
     timers('backward-clip-grad').stop()

-
-def train_step(forward_step_func, data_iterator,
-               model, optimizer, lr_scheduler):
-    """Single training step."""
-    args = get_args()
-    timers = get_timers()
-
-    # Forward model for one step.
-    timers('forward').start()
-    loss, loss_reduced = forward_step_func(data_iterator, model)
-    timers('forward').stop()
-
-    # Calculate gradients, reduce across processes, and clip.
-    timers('backward').start()
-    backward_step(optimizer, model, loss)
-    timers('backward').stop()

     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
@@ -297,7 +421,15 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    return loss_reduced, skipped_iter
+    if mpu.is_inter_layer_last_stage():
+        # Average loss across microbatches.
+        loss_reduced = {}
+        for key in losses_reduced[0]:
+            losses_reduced_for_key = [x[key] for x in losses_reduced]
+            loss_reduced[key] = sum(losses_reduced_for_key) / \
+                len(losses_reduced_for_key)
+        return loss_reduced, skipped_iter
+    return {}, skipped_iter


 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
@@ -382,7 +514,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                           total_loss_dict[got_nan_key])
         total_loss_dict[skipped_iters_key] = 0
         total_loss_dict[got_nan_key] = 0
-        print_rank_0(log_string)
+        print_rank_last(log_string)
         if report_memory_flag:
             report_memory('after {} iterations'.format(iteration))
             report_memory_flag = False
@@ -471,12 +603,32 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
             if verbose and iteration % args.log_interval == 0:
                 print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                             args.eval_iters))
+
+            if not mpu.is_inter_layer_first_stage():
+                input_tensor, _ = communicate(
+                    tensor_send_next=None,
+                    tensor_send_prev=None,
+                    recv_forward=True,
+                    recv_backward=False)
+            else:
+                input_tensor = None
+
             # Forward evaluation.
-            _, loss_dict = forward_step_func(data_iterator, model)
+            output_tensor = forward_step_func(data_iterator, model, input_tensor)

-            # Reduce across processes.
-            for key in loss_dict:
-                total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
-                    loss_dict[key]
+            if mpu.is_inter_layer_last_stage():
+                _, loss_dict = output_tensor
+
+                # Reduce across processes.
+                for key in loss_dict:
+                    total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
+                        loss_dict[key]
+            else:
+                communicate(
+                    tensor_send_next=output_tensor,
+                    tensor_send_prev=None,
+                    recv_forward=False,
+                    recv_backward=False)

     # Move model back to the train mode.
     model.train()
@@ -505,9 +657,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
             writer.add_scalar('{} ppl'.format(key), ppl, iteration)

     length = len(string) + 1
-    print_rank_0('-' * length)
-    print_rank_0(string)
-    print_rank_0('-' * length)
+    print_rank_last('-' * length)
+    print_rank_last(string)
+    print_rank_last('-' * length)


 def build_train_valid_test_data_iterators(
@@ -519,7 +671,7 @@ def build_train_valid_test_data_iterators(
     print_rank_0('> building train, validation, and test datasets ...')

     # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
+    if mpu.get_intra_layer_model_parallel_rank() == 0:
         # Rank, size, and global batch size.
         data_parallel_size = mpu.get_data_parallel_world_size()
         global_batch_size = args.batch_size * data_parallel_size
@@ -557,8 +709,8 @@ def build_train_valid_test_data_iterators(
     # Broadcast num tokens.
     torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
+                                mpu.get_intra_layer_model_parallel_src_rank(),
+                                group=mpu.get_intra_layer_model_parallel_group())
     args.do_train = flags[0].item()
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()
...
@@ -28,14 +28,16 @@ from megatron.data.samplers import DistributedBatchSampler
 from megatron.fp16 import FP16_Optimizer


-def reduce_losses(losses):
+def average_losses_across_data_parallel_group(losses):
     """Reduce a tensor of losses across all GPUs."""
-    reduced_losses = torch.cat(
+    averaged_losses = torch.cat(
         [loss.clone().detach().view(1) for loss in losses])
-    torch.distributed.all_reduce(reduced_losses)
-    reduced_losses = reduced_losses / torch.distributed.get_world_size()
+    torch.distributed.all_reduce(averaged_losses,
+                                 group=mpu.get_data_parallel_group())
+    averaged_losses = averaged_losses / \
+        torch.distributed.get_world_size(group=mpu.get_data_parallel_group())

-    return reduced_losses
+    return averaged_losses


 def report_memory(name):
@@ -56,7 +58,7 @@ def print_params_min_max_norm(optimizer, iteration):
     """Print min, max, and norm of all parameters."""
     index = 0
     rank = torch.distributed.get_rank()
-    string = 'iteration, rank, index, model-parallel,min, max, norm\n'
+    string = 'iteration, rank, index, intra-layer-model-parallel, min, max, norm\n'
     optimizer_ = optimizer
     if isinstance(optimizer, FP16_Optimizer):
         optimizer_ = optimizer.optimizer
@@ -67,7 +69,7 @@ def print_params_min_max_norm(optimizer, iteration):
         max_ = param.data.max()
         norm = param.data.norm()
         string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
-            iteration, rank, index, int(param.model_parallel))
+            iteration, rank, index, int(param.intra_layer_model_parallel))
         string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
     print(string, flush=True)
...
@@ -23,9 +23,9 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import BertModel
+from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
 from megatron.training import pretrain
-from megatron.utils import reduce_losses
+from megatron.utils import average_losses_across_data_parallel_group


 def model_provider():
@@ -33,10 +33,25 @@ def model_provider():

     print_rank_0('building BERT model ...')

-    model = BertModel(
-        num_tokentypes=2,
-        add_binary_head=True,
-        parallel_output=True)
+    args = get_args()
+    if args.inter_layer_model_parallel_size > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_inter_layer_first_stage():
+            model = BertModelFirstStage(
+                num_tokentypes=2)
+        elif mpu.is_inter_layer_last_stage():
+            model = BertModelLastStage(
+                num_tokentypes=2,
+                add_binary_head=True,
+                parallel_output=True)
+        else:
+            model = BertModelIntermediateStage(
+                num_tokentypes=2)
+    else:
+        model = BertModel(
+            num_tokentypes=2,
+            add_binary_head=True,
+            parallel_output=True)

     return model
@@ -66,7 +81,7 @@ def get_batch(data_iterator):
     return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask


-def forward_step(data_iterator, model):
+def forward_step(data_iterator, model, input_tensor):
     """Forward step."""
     args = get_args()
     timers = get_timers()
@@ -77,23 +92,40 @@ def forward_step(data_iterator, model):
         = get_batch(data_iterator)
     timers('batch generator').stop()

-    # Forward model. lm_labels
-    lm_loss_, sop_logits = model(tokens, padding_mask,
-                                 tokentype_ids=types,
-                                 lm_labels=lm_labels)
+    # Forward pass through the model.
+    if mpu.is_inter_layer_first_stage():
+        assert input_tensor is None
+        if mpu.is_inter_layer_last_stage():
+            output_tensor = model(tokens, padding_mask, tokentype_ids=types,
+                                  lm_labels=lm_labels)
+        else:
+            output_tensor = model(tokens, padding_mask, tokentype_ids=types)
+    elif mpu.is_inter_layer_last_stage():
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
+    else:
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, padding_mask)

-    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
-                               sentence_order.view(-1),
-                               ignore_index=-1)
+    if mpu.is_inter_layer_last_stage():
+        lm_loss_, sop_logits = output_tensor

-    lm_loss = torch.sum(
-        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                                   sentence_order.view(-1),
+                                   ignore_index=-1)
+        sop_loss = sop_loss.float()

-    loss = lm_loss + sop_loss
+        lm_loss_ = lm_loss_.float()
+        loss_mask = loss_mask.float()
+        lm_loss = torch.sum(
+            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

-    reduced_losses = reduce_losses([lm_loss, sop_loss])
+        loss = lm_loss + sop_loss

-    return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}
+        averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
+
+        return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
+
+    return output_tensor


 def train_valid_test_datasets_provider(train_val_test_num_samples):
...
@@ -23,16 +23,28 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data.gpt2_dataset import build_train_valid_test_datasets
-from megatron.model import GPT2Model
+from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import reduce_losses
+from megatron.utils import average_losses_across_data_parallel_group


 def model_provider():
     """Build the model."""

     print_rank_0('building GPT2 model ...')

-    model = GPT2Model(num_tokentypes=0, parallel_output=True)
+    args = get_args()
+    if args.inter_layer_model_parallel_size > 1:
+        # Determine model based on position of stage in pipeline.
+        if mpu.is_inter_layer_first_stage():
+            model = GPT2ModelFirstStage(num_tokentypes=0)
+        elif mpu.is_inter_layer_last_stage():
+            model = GPT2ModelLastStage(
+                num_tokentypes=0, parallel_output=True)
+        else:
+            model = GPT2ModelIntermediateStage(
+                num_tokentypes=0)
+    else:
+        model = GPT2Model(num_tokentypes=0, parallel_output=True)

     return model
@@ -69,7 +81,7 @@ def get_batch(data_iterator):
     return tokens, labels, loss_mask, attention_mask, position_ids


-def forward_step(data_iterator, model):
+def forward_step(data_iterator, model, input_tensor):
     """Forward step."""
     args = get_args()
     timers = get_timers()
@@ -79,15 +91,32 @@ def forward_step(data_iterator, model):
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator)
     timers('batch generator').stop()
-    # Forward model.
-    losses = model(tokens, position_ids, attention_mask, labels=labels)
-    loss_mask = loss_mask.view(-1)
-    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

-    # Reduce loss for logging.
-    reduced_loss = reduce_losses([loss])
+    # Forward pass through the model.
+    if mpu.is_inter_layer_first_stage():
+        assert input_tensor is None
+        if mpu.is_inter_layer_last_stage():
+            output_tensor = model(tokens, position_ids, attention_mask,
+                                  labels=labels)
+        else:
+            output_tensor = model(tokens, position_ids, attention_mask)
+    elif mpu.is_inter_layer_last_stage():
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, attention_mask, labels=labels)
+    else:
+        assert input_tensor is not None
+        output_tensor = model(input_tensor, attention_mask)

-    return loss, {'lm loss': reduced_loss[0]}
+    if mpu.is_inter_layer_last_stage():
+        losses = output_tensor.float()
+        loss_mask = loss_mask.view(-1).float()
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+
+        # Reduce loss for logging.
+        averaged_loss = average_losses_across_data_parallel_group([loss])
+
+        return loss, {'lm loss': averaged_loss[0]}
+
+    return output_tensor


 def train_valid_test_datasets_provider(train_val_test_num_samples):
...
@@ -25,12 +25,14 @@ from megatron import get_timers
 from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.training import pretrain
-from megatron.utils import reduce_losses
+from megatron.utils import average_losses_across_data_parallel_group
 from megatron.model.realm_model import general_ict_model_provider
 from megatron.data.realm_dataset_utils import get_ict_batch


 def pretrain_ict_model_provider():
+    args = get_args()
+    assert args.inter_layer_model_parallel_size == 1, 'inter_layer_model_parallel_size must be 1!'
     return general_ict_model_provider(False, False)
@@ -72,7 +74,7 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
         return output


-def forward_step(data_iterator, model):
+def forward_step(data_iterator, model, input_tensor):
     """Forward step."""
     args = get_args()
     timers = get_timers()
@@ -87,7 +89,7 @@ def forward_step(data_iterator, model):
     # Forward model.
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
     local_batch_size = query_logits.shape[0]
-    global_batch_size = dist.get_world_size() * local_batch_size  # recall we assert that model_parallel_size == 1
+    global_batch_size = dist.get_world_size() * local_batch_size  # recall we assert that intra_layer_model_parallel_size == 1
     all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
     all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
@@ -102,11 +104,12 @@ def forward_step(data_iterator, model):
     topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]

     retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
-    reduced_losses = reduce_losses([retrieval_loss, *topk_accs])
+    retrieval_loss = retrieval_loss.float()
+    averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])

     # create stats_dict with retrieval loss and all specified top-k accuracies
-    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, reduced_losses[1:])}
-    stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)
+    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
+    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)

     return retrieval_loss, stats_dict
...
@@ -28,7 +28,7 @@ from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
 from megatron.training import training_log
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import reduce_losses
+from megatron.utils import average_losses_across_data_parallel_group


 def process_batch(batch):
@@ -66,9 +66,9 @@ def _cross_entropy_forward_step(batch, model):
     loss = loss_func(logits.contiguous().float(), labels)

     # Reduce loss for logging.
-    reduced_loss = reduce_losses([loss])
+    averaged_loss = average_losses_across_data_parallel_group([loss])

-    return loss, {'lm loss': reduced_loss[0]}
+    return loss, {'lm loss': averaged_loss[0]}


 def build_data_loader(dataset, batch_size, num_workers, drop_last):
...
@@ -188,12 +188,12 @@ def main():
     # Args
     args = _parse_args(extra_args_provider=get_mp_merge_args)
     model_type = args.model_type
-    orig_model_parallel_size = args.model_parallel_size
-    args.model_parallel_size = 1
+    orig_intra_layer_model_parallel_size = args.intra_layer_model_parallel_size
+    args.intra_layer_model_parallel_size = 1
     tokenizer = rebuild_tokenizer(args)

     print('\n merging model parallel partitions ...')
-    print(' > number of partitions: {}'.format(orig_model_parallel_size))
+    print(' > number of partitions: {}'.format(orig_intra_layer_model_parallel_size))
     print(' > checkpoint path: {}'.format(args.load))
     print(' > model parameters:')
     print('    number of tokens ................ {} '.format(
@@ -207,18 +207,18 @@ def main():

     # Full model.
     print('> building the full model ...')
-    mpu.initialize.set_model_parallel_world_size(1)
-    mpu.initialize.set_model_parallel_rank(0)
+    mpu.initialize.set_intra_layer_model_parallel_world_size(1)
+    mpu.initialize.set_intra_layer_model_parallel_rank(0)
     merged_model = get_model(model_type)

     # Build and load partitions.
     partitions = []
     iteration = 0
-    args.model_parallel_size = orig_model_parallel_size
+    args.intra_layer_model_parallel_size = orig_intra_layer_model_parallel_size
     tokenizer = rebuild_tokenizer(args)
-    mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
-    for rank in range(args.model_parallel_size):
-        mpu.initialize.set_model_parallel_rank(rank)
+    mpu.initialize.set_intra_layer_model_parallel_world_size(args.intra_layer_model_parallel_size)
+    for rank in range(args.intra_layer_model_parallel_size):
+        mpu.initialize.set_intra_layer_model_parallel_rank(rank)
         checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
         print('> loading {} ...'.format(checkpoint_name))
         model_ = get_model(model_type)
@@ -248,7 +248,7 @@ def main():
                 rank, partition_param.dtype, list(partition_param.size())))

             # For the non-parallel parameters, simply copy the rank 0 values.
-            if not hasattr(merged_param, 'model_parallel'):
+            if not hasattr(merged_param, 'intra_layer_model_parallel'):
                 print('     none-parallel parameter, simple copy from rank 0')
                 with torch.no_grad():
                     merged_param.data.copy_(partitions_param[0].data)
@@ -267,8 +267,8 @@ def main():

     # Save the model.
-    args.model_parallel_size = 1
-    mpu.initialize.set_model_parallel_rank(0)
+    args.intra_layer_model_parallel_size = 1
+    mpu.initialize.set_intra_layer_model_parallel_rank(0)
     sd = {}
     sd['model'] = merged_model.state_dict_for_save_checkpoint()
     sd['iteration'] = iteration
...
@@ -136,7 +136,7 @@ def get_args():
     # some default/dummy values for the tokenizer
     args.rank = 0
     args.make_vocab_size_divisible_by = 128
-    args.model_parallel_size = 1
+    args.intra_layer_model_parallel_size = 1

     return args
...