Commit 9939fb58 authored by rprenger

Fixing small merge conflict

parents 811183f0 b6b7ba4d
......@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from megatron.model.enums import AttnMaskType
......@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
......@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
......@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(
inputs, mask, scale_t[0]
)
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
......@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
return input_grads, None, None
class FusedScaleMaskSoftmax(torch.nn.Module):
class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input is in fp16 data format.
input_in_bf16: flag to indicate if input is in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate whether the user wants to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
......@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (self.input_in_fp16 and self.input_in_bf16),\
'both fp16 and bf16 flags cannot be active at the same time.'
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
......@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
data_size = input.size()
query_seq_len = data_size[-2]
key_seq_len = data_size[-1]
attn_batch_size = data_size[0] * data_size[1]
# constraints on various tensor dimensions to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
# invoke custom kernel
if self.input_in_float16 and mask is not None and \
custom_kernel_constraint and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
assert query_seq_len == key_seq_len, \
"causal mask is only for self attention"
input = input.view(-1, query_seq_len, key_seq_len)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size)
else:
assert self.attn_mask_type == AttnMaskType.padding
probs = ScaledMaskedSoftmax.apply(input, mask, scale)
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16 or bf16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be in (16, 2048]
and sq % 4 == 0 # sq must be a multiple of 4
and attn_batches % 4 == 0 # np * b must be a multiple of 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
if self.attn_mask_type == AttnMaskType.causal:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
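# Illustrative sketch (not part of the original diff): the kernel-availability rules above,
# restated as a standalone predicate. `batch_per_block` is passed in here because the real
# value comes from the scaled_masked_softmax_cuda extension; all names and sizes below are hypothetical.
def _would_use_fused_softmax(b, np, sq, sk, batch_per_block,
                             causal, fusion=True, in_float16=True, has_mask=True):
    attn_batches = b * np
    if not (fusion and in_float16 and has_mask
            and 16 < sk <= 2048 and sq % 4 == 0 and attn_batches % 4 == 0):
        return False
    if causal:
        # causal (upper-triangular) kernel batches over b * np
        return attn_batches % batch_per_block == 0
    # padding-mask kernel batches over the query dimension
    return sq % batch_per_block == 0

# e.g. b=8, np=16, sq=sk=1024, batch_per_block=4, causal=True -> True (fused path taken)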
......@@ -53,8 +53,7 @@ class ParallelMLP(MegatronModule):
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
state back into h hidden dimension.
"""
def __init__(self, init_method, output_layer_init_method):
......@@ -84,7 +83,6 @@ class ParallelMLP(MegatronModule):
init_method=output_layer_init_method,
skip_bias_add=True)
def forward(self, hidden_states):
# [s, b, 4hp]
......@@ -544,8 +542,8 @@ class ParallelTransformer(MegatronModule):
self.input_tensor = None
# Store activation checkpointing flag.
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
self.activations_checkpoint_method = args.activations_checkpoint_method
self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
# Number of layers.
assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
......@@ -611,12 +609,31 @@ class ParallelTransformer(MegatronModule):
# Make sure memory is freed.
mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers
if self.activations_checkpoint_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.activations_checkpoint_num_layers),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.activations_checkpoint_num_layers
elif self.activations_checkpoint_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method that fully uses the device memory by removing redundant re-computation.
for l in range(self.num_layers):
if l < self.activations_checkpoint_num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + 1),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation checkpoint method.")
return hidden_states
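# Illustrative sketch (not part of the original diff): which (start, end) layer ranges get
# wrapped in mpu.checkpoint under the two methods above. Purely for exposition; the helper
# name and example numbers are hypothetical.
def _checkpoint_plan(num_layers, chunk, method):
    if method == 'uniform':
        # checkpoint the input activation of every chunk of `chunk` consecutive layers
        return [(l, l + chunk) for l in range(0, num_layers, chunk)]
    elif method == 'block':
        # checkpoint only the first `chunk` individual layers; the rest run without checkpointing
        return [(l, l + 1) for l in range(chunk)]
    raise ValueError("Invalid activation checkpoint method.")

# _checkpoint_plan(8, 2, 'uniform') -> [(0, 2), (2, 4), (4, 6), (6, 8)]
# _checkpoint_plan(8, 2, 'block')   -> [(0, 1), (1, 2)]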
......@@ -639,7 +656,7 @@ class ParallelTransformer(MegatronModule):
'for not None values in layer_past, ' \
'expected get_key_value to be set'
if get_key_value:
assert not self.checkpoint_activations, \
assert self.activations_checkpoint_method is None, \
'get_key_value does not work with ' \
'activation checkpointing'
......@@ -658,7 +675,7 @@ class ParallelTransformer(MegatronModule):
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.checkpoint_activations:
if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
......
......@@ -356,9 +356,13 @@ def get_data_parallel_rank():
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
......@@ -256,7 +256,7 @@ class ColumnParallelLinear(torch.nn.Module):
device=torch.cuda.current_device(), dtype=args.params_dtype))
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if bias:
if args.use_cpu_initialization:
self.bias = Parameter(torch.empty(
......@@ -286,7 +286,7 @@ class ColumnParallelLinear(torch.nn.Module):
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
......@@ -316,8 +316,8 @@ class RowParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
"""
......
......@@ -20,7 +20,7 @@ from .utils import split_tensor_along_last_dim
def _reduce(input_):
"""All-reduce the the input tensor across model parallel group."""
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size()==1:
......
......@@ -47,9 +47,18 @@ def init_checkpointed_activations_memory_buffer():
per_layer = args.micro_batch_size * args.max_position_embeddings * \
args.hidden_size // args.tensor_model_parallel_size
assert args.num_layers % args.checkpoint_num_layers == 0, \
'number of layers is not divisible by checkpoint-num-layers'
num_checkpointer_layers = args.num_layers // args.checkpoint_num_layers
num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
if args.virtual_pipeline_model_parallel_size is not None:
num_layers = num_layers // args.virtual_pipeline_model_parallel_size
if args.activations_checkpoint_method == 'uniform':
assert num_layers % args.activations_checkpoint_num_layers == 0, \
'total number of layers is not divisible by checkpoint-chunk_size'
num_checkpointer_layers = args.num_layers // args.activations_checkpoint_num_layers
elif args.activations_checkpoint_method == 'block':
assert args.activations_checkpoint_num_layers <= num_layers, \
'total number of layers is fewer than the number of layers to checkpoint'
num_checkpointer_layers = args.activations_checkpoint_num_layers
numel = per_layer * num_checkpointer_layers
dtype = torch.half
if not args.fp16:
......
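# Illustrative sketch (not part of the original diff): rough arithmetic for the buffer size
# computed above, using hypothetical settings.
micro_batch_size, seq_len, hidden, tp = 4, 2048, 4096, 8
per_layer = micro_batch_size * seq_len * hidden // tp   # 4,194,304 elements per checkpointed layer
num_layers, chunk = 48, 2
num_checkpointer_layers = num_layers // chunk            # 'uniform' method: 24 checkpoints
numel = per_layer * num_checkpointer_layers              # ~100M elements, about 192 MiB at fp16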
......@@ -100,10 +100,12 @@ def get_megatron_optimizer(model):
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp,
args.bf16,
grad_scaler)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad)
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp)
......@@ -68,7 +68,9 @@ class MegatronOptimizer(ABC):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad):
params_have_main_grad,
use_contiguous_buffers_in_local_ddp):
"""Input optimizer is the base optimizer for example Adam."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
......@@ -76,7 +78,11 @@ class MegatronOptimizer(ABC):
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
self.params_have_main_grad = params_have_main_grad
self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
if self.use_contiguous_buffers_in_local_ddp:
assert self.params_have_main_grad, \
"use of contiguous buffer requires that params have main grad"
def get_parameters(self):
params = []
......@@ -187,11 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
"""
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, bf16, grad_scaler):
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler):
super(Float16OptimizerWithFloat16Params, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad)
params_have_main_grad, use_contiguous_buffers_in_local_ddp)
self.bf16 = bf16
self.grad_scaler = grad_scaler
......@@ -282,9 +289,14 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
float16_groups & fp32_from_fp32_groups."""
float16_groups & fp32_from_fp32_groups. We additionally zero
fp32_from_float16_groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point."""
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_fp32_groups:
_zero_grad_group_helper(group, set_to_none)
......@@ -305,12 +317,26 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
# Safe to deallocate model's grad/main_grad after copying.
# (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.)
model_param.grad = None
if self.params_have_main_grad and \
not self.use_contiguous_buffers_in_local_ddp:
model_param.main_grad = None
# For fp32 grads, we need to reset the grads to main grad.
if self.params_have_main_grad:
for model_group in self.fp32_from_fp32_groups:
for model_param in model_group:
model_param.grad = model_param.main_grad
# Safe to de-reference model's main_grad after copying.
# (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.)
if not self.use_contiguous_buffers_in_local_ddp:
model_param.main_grad = None
def _unscale_main_grads_and_check_for_nan(self):
main_grads = []
......@@ -464,11 +490,12 @@ class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad):
params_have_main_grad,
use_contiguous_buffers_in_local_ddp):
super(FP32Optimizer, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad)
params_have_main_grad, use_contiguous_buffers_in_local_ddp)
self._scale = torch.cuda.FloatTensor([1.0])
......@@ -495,6 +522,12 @@ class FP32Optimizer(MegatronOptimizer):
for param in param_group['params']:
param.grad = param.main_grad
# Safe to de-reference model's main_grad after copying.
# (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.)
if not self.use_contiguous_buffers_in_local_ddp:
param.main_grad = None
# Clip gradients.
grad_norm = None
if self.clip_grad > 0.0:
......
......@@ -22,7 +22,9 @@ from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False):
use_ring_exchange=False, tensor_shape=None,
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None):
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
......@@ -37,7 +39,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
next rank.
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
API should be used.
tensor_shape: optional, use when the input sequence contains fewer
tokens than the default sequence length
override_scatter_gather_tensors_in_pipeline: optional, used together with
tensor_shape to override the
scatter/gather of tensors in the
pipeline
dtype_: optional, used when tensor_shape is provided; specifies the dtype
of the tensor to communicate
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
......@@ -47,8 +56,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.scatter_gather_tensors_in_pipeline:
if tensor_shape is None:
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
mpu.get_tensor_model_parallel_world_size()
else:
......@@ -56,19 +67,26 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(tensor_chunk_shape,
requires_grad=True,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_chunk_shape,
requires_grad=True,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype)
# Split tensor into smaller chunks if using scatter-gather optimization.
if args.scatter_gather_tensors_in_pipeline:
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
......@@ -112,7 +130,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if args.scatter_gather_tensors_in_pipeline:
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
......@@ -124,8 +143,11 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
return tensor_recv_prev, tensor_recv_next
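# Illustrative sketch (not part of the original diff): how the new tensor_shape /
# override_scatter_gather_tensors_in_pipeline arguments change the shape of the receive
# buffer allocated above. The helper name and sizes are hypothetical.
import operator
from functools import reduce

def _recv_buffer_shape(default_shape, tensor_shape=None,
                       scatter_gather=True, override_scatter_gather=False,
                       tp_world_size=8):
    shape = tensor_shape if tensor_shape is not None else default_shape
    if scatter_gather and not override_scatter_gather:
        # scatter-gather optimization: communicate a flat 1/tp_world_size chunk
        return (reduce(operator.mul, shape, 1) // tp_world_size,)
    return shape

# default training path, (seq_len, micro_batch, hidden) = (2048, 4, 4096), tp = 8:
#   _recv_buffer_shape((2048, 4, 4096)) -> (4194304,)
# inference with a shorter prompt, scatter-gather overridden:
#   _recv_buffer_shape((2048, 4, 4096), tensor_shape=(128, 4, 4096),
#                      override_scatter_gather=True) -> (128, 4, 4096)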
def recv_forward(timers=None):
def recv_forward(tensor_shape=None,
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None, timers=None):
"""Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
......@@ -135,7 +157,11 @@ def recv_forward(timers=None):
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False)
recv_next=False,
tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=\
override_scatter_gather_tensors_in_pipeline,
dtype_=dtype_)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
......@@ -158,8 +184,11 @@ def recv_backward(timers=None):
return output_tensor_grad
def send_forward(output_tensor, timers=None):
def send_forward(output_tensor, timers=None,
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None):
"""Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
......@@ -167,7 +196,10 @@ def send_forward(output_tensor, timers=None):
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False)
recv_next=False,
override_scatter_gather_tensors_in_pipeline=\
override_scatter_gather_tensors_in_pipeline,
dtype_=dtype_)
if timers is not None:
timers('forward-send').stop()
......
......@@ -31,6 +31,9 @@ def get_forward_backward_func():
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'number of microbatches is not divisible by pipeline-parallel ' \
'size when using interleaved schedule'
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
......@@ -191,6 +194,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# forward step
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
......@@ -202,6 +206,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
return output_tensor
def backward_step_helper(microbatch_id):
......@@ -228,7 +237,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(timers))
p2p_communication.recv_forward(timers=timers))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
......@@ -262,7 +271,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev, timers)
output_tensor, recv_prev=recv_prev, timers=timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
......@@ -340,7 +349,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(timers))
p2p_communication.recv_backward(timers=timers))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
......@@ -352,7 +361,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next, timers))
input_tensor_grad, recv_next=recv_next, timers=timers))
return losses_reduced
......@@ -380,25 +389,30 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
input_tensors = []
output_tensors = []
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
if not forward_only:
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers)
input_tensor = p2p_communication.recv_forward(timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
p2p_communication.send_forward(output_tensor, timers)
p2p_communication.send_forward(output_tensor, timers=timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = p2p_communication.recv_forward(timers)
input_tensor = p2p_communication.recv_forward(timers=timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
......@@ -407,22 +421,24 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
p2p_communication.send_forward(output_tensor, timers)
p2p_communication.send_forward(output_tensor, timers=timers)
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers=timers)
else:
output_tensor_grad = \
p2p_communication.send_forward_recv_backward(output_tensor,
timers)
timers=timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if forward_only:
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
......@@ -430,11 +446,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if last_iteration:
input_tensor = None
p2p_communication.send_backward(input_tensor_grad, timers)
p2p_communication.send_backward(input_tensor_grad, timers=timers)
else:
input_tensor = \
p2p_communication.send_backward_recv_forward(
input_tensor_grad, timers)
input_tensor_grad, timers=timers)
# Run cooldown backward passes.
if not forward_only:
......@@ -442,12 +458,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers)
output_tensor_grad = p2p_communication.recv_backward(timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
p2p_communication.send_backward(input_tensor_grad, timers)
p2p_communication.send_backward(input_tensor_grad, timers=timers)
return losses_reduced
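# Illustrative sketch (not part of the original diff): the forward-only bookkeeping change
# above in isolation. When forward_only is set, activations are never retained, so there is
# nothing to pop for a backward step; otherwise the oldest saved pair feeds the 1F1B backward.
def _steady_state_bookkeeping(forward_only, input_tensors, output_tensors,
                              input_tensor, output_tensor):
    if forward_only:
        return None, None
    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)
    return input_tensors.pop(0), output_tensors.pop(0)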
......@@ -121,15 +121,15 @@ def receive_generate_info():
"""
Needs to be synced up with send_generate_info
"""
input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.device("cuda"))
input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.cuda.current_device())
torch.distributed.broadcast(input_info_tensor, 0)
batch_size = input_info_tensor[0].item()
seq_len = input_info_tensor[1].item()
max_len = input_info_tensor[2].item()
all_probs = input_info_tensor[3].item()
context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.device("cuda"))
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.device("cuda"))
context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
# Send variables to all ranks
torch.distributed.broadcast(context_length_tensor, 0)
......@@ -175,6 +175,7 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len
if tokens is not None:
return tokens[:, :context_length], output_logits, full_logits
<<<<<<< HEAD
def generate(model, sentences=None, max_len=0, all_probs=False):
if torch.distributed.get_rank() == 0:
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
......@@ -182,6 +183,13 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
b = context_tokens_tensor.size(0)
start = time.time()
send_generate_info(context_tokens_tensor, context_length_tensor, max_len, all_probs)
=======
def generate(model, sentences=None, max_len=0):
model.eval()
if torch.distributed.get_rank() == 0:
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
>>>>>>> server
else:
context_length_tensor, context_tokens_tensor, max_len, all_probs = receive_generate_info()
......@@ -198,6 +206,7 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
decode_tokens = decode_tokens.cpu().numpy().tolist()
for decode_token in decode_tokens:
resp_sentences.append(tokenizer.detokenize(decode_token))
<<<<<<< HEAD
words = []
for token in decode_token:
word = tokenizer.tokenizer.decoder[token]
......@@ -212,6 +221,21 @@ def generate(model, sentences=None, max_len=0, all_probs=False):
end = time.time()
print(str(b)+","+str(c)+","+str(len(decode_tokens[0]))+","+str(end-start), flush=True)
return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
=======
return resp_sentences
>>>>>>> server
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
"""
This function is here to provide a matching API for a legacy task.
This implementation hasn't been tested yet to make sure it matches.
"""
assert False, "Implementation untested"
args = get_args()
args.eos_id = eos_token_id
raw_text_len = len(context)
resp_sentences = generate(model, [context], max_gen_length)
return resp_sentences[0][raw_text_len:]
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
......
......@@ -47,9 +47,7 @@ from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.schedules import get_forward_backward_func
from megatron.utils import report_memory
......@@ -98,7 +96,7 @@ def pretrain(train_valid_test_dataset_provider,
# This will be closer to what the scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
......@@ -255,7 +253,7 @@ def get_model(model_provider_func):
if args.DDP_impl == 'local':
model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_ddp)
args.use_contiguous_buffers_in_local_ddp)
for model_module in model]
return model
......@@ -353,26 +351,20 @@ def train_step(forward_step_func, data_iterator,
timers = get_timers()
# Set grad to zero.
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
for partition in model:
partition.zero_grad_buffer()
else:
optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'number of microbatches is not divisible by pipeline-parallel ' \
'size when using interleaved schedule'
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
optimizer.zero_grad()
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# Empty unused memory
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
......@@ -419,6 +411,10 @@ def train_step(forward_step_func, data_iterator,
else:
skipped_iter = 1
# Empty unused memory
if args.empty_unused_memory_level >= 2:
torch.cuda.empty_cache()
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Average loss across microbatches.
loss_reduced = {}
......@@ -531,11 +527,28 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
if args.log_timers_to_tensorboard:
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if args.log_memory_to_tensorboard:
mem_stats = torch.cuda.memory_stats()
writer.add_scalar(
"mem-reserved-bytes",
mem_stats["reserved_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-bytes",
mem_stats["allocated_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-allocated-count",
mem_stats["allocation.all.current"],
iteration,
)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval-time').elapsed()
elapsed_time_per_iteration = elapsed_time / total_iterations
if writer and torch.distributed.get_rank() == 0:
if writer:
if args.log_timers_to_tensorboard:
writer.add_scalar('iteration-time',
elapsed_time_per_iteration, iteration)
......@@ -705,17 +718,15 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
forward_backward_func = get_forward_backward_func()
loss_dicts = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True)
# Empty unused memory
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
if mpu.is_pipeline_last_stage(ignore_virtual=True):
# Reduce across processes.
for loss_dict in loss_dicts:
......@@ -748,7 +759,7 @@ def evaluate_and_print_results(prefix, forward_step_func,
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
if writer and is_last_rank():
if writer:
writer.add_scalar('{} validation'.format(key),
total_loss_dict[key].item(),
iteration)
......@@ -787,10 +798,9 @@ def build_train_valid_test_data_iterators(
'only backward compatibility support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
assert args.train_samples is None, \
'only backward compatiblity support for iteration-based training'
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * args.global_batch_size
if args.train_samples is None:
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * args.global_batch_size
# Data loader only on rank 0 of each model parallel group.
if mpu.get_tensor_model_parallel_rank() == 0:
......
#!/bin/bash
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
CHECKPOINT="/home/universal-lm-data.cosmos549/chkpts/gpt2/8.3B_no_rng"
DATA_PATH="/home/universal-lm-data.cosmos549/scratch/mshoeybi/data/gpt2"
VOCAB_FILE="${DATA_PATH}/bpe/gpt2-vocab.json"
MERGE_FILE="${DATA_PATH}/bpe/gpt2-merges.txt"
python -m torch.distributed.launch $DISTRIBUTED_ARGS tools/run_api_server.py \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 1 \
--num-layers 72 \
--hidden-size 3072 \
--load $CHECKPOINT \
--num-attention-heads 24 \
--max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer \
--fp16 \
--micro-batch-size 1 \
--seq-length 1024 \
--out-seq-length 1024 \
--temperature 1.0 \
--vocab-file $VOCAB_FILE \
--merge-file $MERGE_FILE \
--top_p 0.9 \
--seed 42
## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
Below we present the steps to run unsupervised and supervised training and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
## Retriever Training
#### Unsupervised pretraining
1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
<pre>
python tools/preprocess_data.py \
--input /path/to/corpus.json \
--json-keys text title \
--split-sentences \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file /path/to/vocab.txt \
--output-prefix corpus_indexed \
--workers 10
</pre>
2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model, and we use a total batch size of 4096 for ICT training.
3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf).
#### Supervised finetuning
1. Use the above pretrained ICT model to finetune on [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example of how to perform the training. Our finetuning process includes retriever score scaling and longer training (80 epochs) on top of [DPR training](https://arxiv.org/abs/2004.04906).
2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model.
More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).
## Reader Training
The reader component will be available soon.
......@@ -244,6 +244,9 @@ def normalize_question(question):
question = question[:-1]
return question
# The following class reads the datasets for training the retriever, as
# prepared by the DPR codebase (https://github.com/facebookresearch/DPR)
class NQSupervisedDataset(OpenRetrievalAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length, \
......
......@@ -33,6 +33,28 @@ from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator
# input_ is a 2D tensor
def check_and_append_tensor_for_gather(group, rank, world_size, input_):
# gather the size of the first dimension of the tensor from all ranks
current_length = input_.size()[0]
first_dim = torch.tensor([[current_length]],
device=torch.cuda.current_device())
input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
input_list[rank].copy_(first_dim)
torch.distributed.all_gather(input_list, first_dim, group=group)
all_input_list = torch.cat(input_list, dim=0).contiguous()
max_length = torch.max(all_input_list)
# if the size is different from the max, extend the tensor
# accordingly
if max_length > current_length:
padding=tuple([0] * (input_.dim() * 2 - 1)) + \
tuple([max_length - current_length])
input_ = F.pad(input=input_, pad=padding)
return input_
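# Illustrative sketch (not part of the original diff): the F.pad call above extends the
# first dimension of a 2-D tensor up to the longest length seen across ranks. Sizes here
# are hypothetical.
import torch
import torch.nn.functional as F

x = torch.ones(3, 5)                                       # current_length = 3
max_length = 7
padding = (0,) * (x.dim() * 2 - 1) + (max_length - 3,)     # (0, 0, 0, 4): pad rows at the end
x_padded = F.pad(input=x, pad=padding)
assert x_padded.shape == (7, 5)                            # zero rows appended at the bottom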
def orqa(Dataset):
def cross_entropy_forward_step(batch, model):
......@@ -47,6 +69,8 @@ def orqa(Dataset):
except BaseException:
batch_ = batch
group, rank, world_size = get_group_world_size_rank()
query_tokens, query_mask, query_types, query_pad_mask, \
context_tokens, context_mask, context_types, context_pad_mask, \
neg_context_tokens, neg_context_mask, neg_context_types, \
......@@ -61,6 +85,14 @@ def orqa(Dataset):
query_list.append(tokenizer.decode(query_tokens[i].tolist()))
context_list.append(tokenizer.decode(context_tokens[i].tolist()))
if neg_context_tokens is not None:
neg_context_tokens = check_and_append_tensor_for_gather(group,
rank, world_size, neg_context_tokens)
neg_context_mask = check_and_append_tensor_for_gather(group,
rank, world_size, neg_context_mask)
neg_context_types = check_and_append_tensor_for_gather(group,
rank, world_size, neg_context_types)
if neg_context_tokens is not None:
context_tokens = torch.cat([context_tokens, neg_context_tokens])
context_mask = torch.cat([context_mask, neg_context_mask])
......@@ -70,7 +102,6 @@ def orqa(Dataset):
output_tensor = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
......
......@@ -20,7 +20,7 @@ python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for
4. Merge the contents into one loose JSON file with one JSON object per line in the format `{'text': text, 'url': unique_url}`. It is important for the URL to be unique.
# Prepare the data for GPT-2 training:
# Prepare the data for GPT training:
1. Perform ftfy, English detection, and remove documents with fewer than 128 tokens. This step can be sharded and run on shards.
```
......@@ -50,7 +50,7 @@ shuf <cleaned deduped data file> -o train_data.json
To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.
```
python filter_ngrams.py --tasks <name of he task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
python filter_ngrams.py --tasks <name of the task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
```
We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from both sides of the 13-gram. We also remove any split document with fewer than 200 characters, or any document that was split more than 10 times. These parameters can be changed using the corresponding arguments.
......
......@@ -26,7 +26,7 @@ from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.api_server import MegatronServer
from megatron.text_generation_server import MegatronServer
from megatron.text_generation_utils import generate
import torch
......