Commit 4f4715db authored by Jared Casper

Merge branch 't5_pipeline_parallelism' into t5_pipeline_parallelism_grad_norm_fix

parents 1f4966dc 1c3a1e6a
......@@ -176,14 +176,10 @@ def _initialize_distributed():
args.local_rank = device
torch.cuda.set_device(device)
# Call the init process
-init_method = 'tcp://'
-master_ip = os.getenv('MASTER_ADDR', 'localhost')
-master_port = os.getenv('MASTER_PORT', '6000')
-init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=args.distributed_backend,
-world_size=args.world_size, rank=args.rank,
-init_method=init_method)
+world_size=args.world_size, rank=args.rank)
# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
......
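Side note on the hunk above: dropping the hand-built `tcp://` string works because `init_process_group` defaults to the `env://` rendezvous, which reads the same variables from the environment. A minimal single-process sketch (illustrative values, not Megatron code):

```python
# With init_method omitted, torch.distributed falls back to "env://"
# and reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE itself.
import os
import torch

os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '6000')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')

torch.distributed.init_process_group(backend='gloo')  # 'nccl' on GPU clusters
print(torch.distributed.get_rank(), torch.distributed.get_world_size())
torch.distributed.destroy_process_group()
```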
......@@ -53,8 +53,7 @@ class ParallelMLP(MegatronModule):
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
-state back into h hidden dimension. At the end, dropout is also
-applied.
+state back into h hidden dimension.
"""
def __init__(self, init_method, output_layer_init_method):
......@@ -84,7 +83,6 @@ class ParallelMLP(MegatronModule):
init_method=output_layer_init_method,
skip_bias_add=True)
def forward(self, hidden_states):
# [s, b, 4hp]
......
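For readers skimming the docstring change: the shape it describes is the standard transformer feed-forward block. A plain single-GPU sketch of that pattern (not the tensor-parallel `ParallelMLP` itself):

```python
import torch
import torch.nn.functional as F

class SimpleMLP(torch.nn.Module):
    """h -> 4h -> nonlinearity -> h, as the docstring describes."""
    def __init__(self, hidden_size):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, hidden_states):
        # [s, b, h] -> [s, b, 4h] -> [s, b, h]
        return self.dense_4h_to_h(F.gelu(self.dense_h_to_4h(hidden_states)))

assert SimpleMLP(8)(torch.randn(4, 2, 8)).shape == (4, 2, 8)
```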
......@@ -457,9 +457,13 @@ def get_data_parallel_rank():
def destroy_model_parallel():
"""Set the groups to none."""
+global _MODEL_PARALLEL_GROUP
+_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
+global _EMBEDDING_GROUP
+_EMBEDDING_GROUP = None
......@@ -256,7 +256,7 @@ class ColumnParallelLinear(torch.nn.Module):
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if bias:
if args.use_cpu_initialization:
self.bias = Parameter(torch.empty(
......@@ -286,7 +286,7 @@ class ColumnParallelLinear(torch.nn.Module):
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
-output = output_parallel
+output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
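The `skip_bias_add=True` path above is the reason `forward` returns an `(output, bias)` pair: the bias is handed back unapplied so a downstream kernel can fuse it with another elementwise op. A hedged sketch of the pattern (illustrative names):

```python
import torch
import torch.nn.functional as F

def linear_skip_bias_add(input_, weight, bias):
    # bias deliberately not added here
    return F.linear(input_, weight), bias

x, w, b = torch.randn(2, 8), torch.randn(16, 8), torch.randn(16)
out, out_bias = linear_skip_bias_add(x, w, b)
fused = F.gelu(out + out_bias)  # bias applied together with the activation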
......@@ -316,8 +316,8 @@ class RowParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
-skip_bias_add: This was added to enable performance optimations where bias
-can be fused with other elementwise operations. we skip
+skip_bias_add: This was added to enable performance optimization where bias
+can be fused with other elementwise operations. We skip
adding bias but instead return it.
"""
......
......@@ -20,7 +20,7 @@ from .utils import split_tensor_along_last_dim
def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size()==1:
......
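The bypass shown above avoids a pointless collective: an all-reduce over a group of size one returns its input unchanged but still pays kernel-launch latency. A sketch of the guard (stand-in names for the tensor-model-parallel group):

```python
import torch

def reduce_across_group(input_, group=None, world_size=1):
    if world_size == 1:
        return input_  # single rank: the all-reduce would be a no-op
    torch.distributed.all_reduce(input_, group=group)
    return input_
```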
......@@ -100,10 +100,12 @@ def get_megatron_optimizer(model):
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
+args.use_contiguous_buffers_in_ddp,
args.bf16,
grad_scaler)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
-params_have_main_grad)
+params_have_main_grad,
+args.use_contiguous_buffers_in_ddp)
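Why the optimizer needs to know about `use_contiguous_buffers_in_ddp`: with contiguous buffers each parameter's `main_grad` is a view into one flat allocation, so setting `main_grad = None` would only drop the view while the shared storage persists. A toy illustration (names assumed, not the DDP implementation):

```python
import torch

params = [torch.nn.Parameter(torch.randn(4)),
          torch.nn.Parameter(torch.randn(6))]
flat_buffer = torch.zeros(sum(p.numel() for p in params))

offset = 0
for p in params:
    # each main_grad is a view into the shared flat buffer
    p.main_grad = flat_buffer[offset:offset + p.numel()].view_as(p)
    offset += p.numel()

params[0].main_grad.add_(1.0)             # writes land in flat_buffer
assert bool(flat_buffer[:4].eq(1.0).all())
```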
......@@ -68,7 +68,9 @@ class MegatronOptimizer(ABC):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
-params_have_main_grad):
+params_have_main_grad,
+use_contiguous_buffers_in_ddp):
"""Input optimizer is the base optimizer for example Adam."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
......@@ -76,7 +78,11 @@ class MegatronOptimizer(ABC):
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
self.params_have_main_grad = params_have_main_grad
+self.use_contiguous_buffers_in_ddp = use_contiguous_buffers_in_ddp
+if self.use_contiguous_buffers_in_ddp:
+assert self.params_have_main_grad, \
+"use of contiguous buffer requires that params have main grad"
def get_parameters(self):
params = []
......@@ -187,11 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
"""
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-params_have_main_grad, bf16, grad_scaler):
+params_have_main_grad, use_contiguous_buffers_in_ddp,
+bf16, grad_scaler):
super(Float16OptimizerWithFloat16Params, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
-params_have_main_grad)
+params_have_main_grad, use_contiguous_buffers_in_ddp)
self.bf16 = bf16
self.grad_scaler = grad_scaler
......@@ -310,12 +317,26 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
+# Safe to deallocate model's grad/main_grad after copying.
+# (If using contiguous buffers, main_grad's memory should
+# persist and therefore should not be deallocated.)
+model_param.grad = None
+if self.params_have_main_grad and \
+not self.use_contiguous_buffers_in_ddp:
+model_param.main_grad = None
# For fp32 grads, we need to reset the grads to main grad.
if self.params_have_main_grad:
for model_group in self.fp32_from_fp32_groups:
for model_param in model_group:
model_param.grad = model_param.main_grad
+# Safe to de-reference model's main_grad after copying.
+# (If using contiguous buffers, main_grad's memory should
+# persist and therefore should not be deallocated.)
+if not self.use_contiguous_buffers_in_ddp:
+model_param.main_grad = None
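A toy version of the copy these comments guard, with a single fp16 parameter: the model gradient is upcast to a float32 main gradient for the optimizer, and the fp16 gradient can then be freed because the fp32 copy owns separate storage (which is exactly what is not true of a `main_grad` view into a contiguous buffer):

```python
import torch

model_param = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
model_param.grad = torch.randn(4, dtype=torch.float16)

main_param = torch.nn.Parameter(model_param.detach().float())
main_param.grad = model_param.grad.float()  # upcast copy for the optimizer

model_param.grad = None  # safe: the fp32 copy is independent storage
```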
def _unscale_main_grads_and_check_for_nan(self):
main_grads = []
......@@ -469,11 +490,12 @@ class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
-params_have_main_grad):
+params_have_main_grad,
+use_contiguous_buffers_in_ddp):
super(FP32Optimizer, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
-params_have_main_grad)
+params_have_main_grad, use_contiguous_buffers_in_ddp)
self._scale = torch.cuda.FloatTensor([1.0])
......@@ -500,6 +522,12 @@ class FP32Optimizer(MegatronOptimizer):
for param in param_group['params']:
param.grad = param.main_grad
+# Safe to de-reference model's main_grad after copying.
+# (If using contiguous buffers, main_grad's memory should
+# persist and therefore should not be deallocated.)
+if not self.use_contiguous_buffers_in_ddp:
+param.main_grad = None
# Clip gradients.
grad_norm = None
if self.clip_grad > 0.0:
......
......@@ -243,6 +243,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
+# forward step
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
......@@ -254,6 +255,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
+# if forward-only, no need to save tensors for a backward pass
+if forward_only:
+input_tensors[model_chunk_id].pop()
+output_tensors[model_chunk_id].pop()
return output_tensor
def backward_step_helper(microbatch_id):
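The pop-after-append added above can be read in isolation: activations are queued only because a later backward step will consume them, so in forward-only mode they are discarded immediately to bound memory. A standalone sketch (illustrative names):

```python
input_tensors, output_tensors = [], []

def forward_helper(input_tensor, output_tensor, forward_only):
    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)
    if forward_only:  # no backward pass will ever pop these
        input_tensors.pop()
        output_tensors.pop()
    return output_tensor
```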
......@@ -538,8 +544,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
send_tensor_shapes = get_tensor_shapes(rank, model_type)
-input_tensors = []
-output_tensors = []
+# Input, output tensors only need to be saved when doing backward passes
+input_tensors = None
+output_tensors = None
+if not forward_only:
+input_tensors = []
+output_tensors = []
losses_reduced = []
# Run warmup forward passes.
......@@ -549,8 +559,9 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensor, losses_reduced)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
-input_tensors.append(input_tensor)
-output_tensors.append(output_tensor)
+if not forward_only:
+input_tensors.append(input_tensor)
+output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
......@@ -566,21 +577,24 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensor, losses_reduced)
if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers)
+if not last_iteration:
+input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
else:
output_tensor_grad = \
send_forward_recv_backward(output_tensor,
send_tensor_shapes,
timers=timers)
-# Add input_tensor and output_tensor to end of list, then pop from the
-# start of the list for backward pass.
-input_tensors.append(input_tensor)
-output_tensors.append(output_tensor)
-if forward_only:
-if not last_iteration:
-input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
-else:
-input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
+# Add input_tensor and output_tensor to end of list.
+input_tensors.append(input_tensor)
+output_tensors.append(output_tensor)
+# Pop input_tensor and output_tensor from the start of the list for
+# the backward pass.
+input_tensor = input_tensors.pop(0)
+output_tensor = output_tensors.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
......
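The restructured loop keeps the 1F1B invariant explicit: each microbatch's tensors are appended after its forward step, and the oldest pair is popped for the backward step that is now due, i.e. a FIFO queue per stage. Minimal sketch:

```python
input_tensors, output_tensors = [], []

def one_forward_one_backward(input_tensor, output_tensor):
    input_tensors.append(input_tensor)        # newest forward result
    output_tensors.append(output_tensor)
    # the oldest microbatch is the one whose backward step runs next
    return input_tensors.pop(0), output_tensors.pop(0)
```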
......@@ -99,7 +99,7 @@ def pretrain(train_valid_test_dataset_provider,
# This will be closer to what scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
-start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
+start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
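The `FloatTensor` to `DoubleTensor` switch is a precision fix: a Unix timestamp (about 1.7e9 seconds) is too large for float32's 24-bit mantissa, so the MIN-reduced start time could be off by a minute or more. Quick CPU check:

```python
import time
import torch

t = time.time()
print(abs(torch.tensor([t], dtype=torch.float32).item() - t))  # up to ~100 s
print(abs(torch.tensor([t], dtype=torch.float64).item() - t))  # ~0.0
```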
......@@ -391,6 +391,10 @@ def train_step(forward_step_func, data_iterator,
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
+# Empty unused memory
+if args.empty_unused_memory_each_iter >= 1:
+torch.cuda.empty_cache()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
......@@ -438,6 +442,10 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
+# Empty unused memory
+if args.empty_unused_memory_each_iter >= 2:
+torch.cuda.empty_cache()
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
loss_reduced = {}
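Taken together, the two guards treat the flag as a level rather than a boolean (semantics inferred from the hunks): level 1 frees cached allocator blocks after the forward/backward work, level 2 also frees them after the optimizer step. Sketch:

```python
import torch

def maybe_empty_cache(level, threshold):
    # threshold 1: after forward/backward; 2: also after the optimizer step
    if level >= threshold and torch.cuda.is_available():
        torch.cuda.empty_cache()
```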
......@@ -571,7 +579,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
if iteration % args.log_interval == 0:
elapsed_time = timers('interval-time').elapsed()
elapsed_time_per_iteration = elapsed_time / total_iterations
-if writer and torch.distributed.get_rank() == 0:
+if writer:
if args.log_timers_to_tensorboard:
writer.add_scalar('iteration-time',
elapsed_time_per_iteration, iteration)
......@@ -746,6 +754,10 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True)
+# Empty unused memory
+if args.empty_unused_memory_each_iter >= 1:
+torch.cuda.empty_cache()
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Reduce across processes.
for loss_dict in loss_dicts:
......@@ -778,7 +790,7 @@ def evaluate_and_print_results(prefix, forward_step_func,
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
-if writer and is_last_rank():
+if writer:
writer.add_scalar('{} validation'.format(key),
total_loss_dict[key].item(),
iteration)
......@@ -817,10 +829,9 @@ def build_train_valid_test_data_iterators(
'only backward compatiblity support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
-assert args.train_samples is None, \
-'only backward compatiblity support for iteration-based training'
-args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
-args.eval_iters * args.global_batch_size
+if args.train_samples is None:
+args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
+args.eval_iters * args.global_batch_size
# Data loader only on rank 0 of each model parallel group.
if mpu.get_tensor_model_parallel_rank() == 0:
......
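The assertion above becomes a guard: sample-based runs checkpoint their counters, so only iteration-based resumes need them reconstructed. A worked example of the arithmetic with illustrative numbers:

```python
# Resuming iteration-based training at iteration 1000 with
# eval_interval=100, eval_iters=10, global_batch_size=512:
iteration, eval_interval, eval_iters, global_batch_size = 1000, 100, 10, 512

consumed_train_samples = iteration * global_batch_size              # 512000
# evaluation ran iteration // eval_interval = 10 times, eval_iters batches each
consumed_valid_samples = (iteration // eval_interval) * \
    eval_iters * global_batch_size                                  # 51200
```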