Commit 67aa8619 authored by Jared Casper

Merge branch 'main' into checkpoint_util

parents 03d09af0 f5345dfa
...@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter ...@@ -27,6 +27,7 @@ from torch.nn.parameter import Parameter
from .initialize import get_tensor_model_parallel_rank from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region from .mappings import reduce_from_tensor_model_parallel_region
...@@ -200,6 +201,37 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -200,6 +201,37 @@ class VocabParallelEmbedding(torch.nn.Module):
return output return output
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
"""
Column-parallel linear layer execution with asynchronous all-reduce
execution in backprop.
"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of the weight gradient computation briefly (3us) so the
# all-reduce is scheduled first and has GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return grad_input, grad_weight, grad_bias
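The overlap pattern above can be exercised in isolation. The sketch below is illustrative only (it is not the Megatron layer): it assumes a throwaway single-process gloo group and CPU tensors, and it all-reduces over the default group rather than the tensor-model-parallel group.

import os
import torch
import torch.distributed as dist

class ToyLinearWithAsyncAllreduce(torch.autograd.Function):
    """Backward overlaps the all-reduce of grad_input with the weight-grad matmul."""
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        output = torch.matmul(input, weight.t())
        return output + bias if bias is not None else output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        # Kick off the all-reduce without blocking...
        handle = dist.all_reduce(grad_input, async_op=True)
        # ...and compute the weight/bias gradients while it runs.
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
        handle.wait()  # grad_input is fully reduced from here on
        return grad_input, grad_weight, grad_bias

if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=True)
    ToyLinearWithAsyncAllreduce.apply(x, w, None).sum().backward()
    print(x.grad.shape, w.grad.shape)  # torch.Size([4, 8]) torch.Size([16, 8])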
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
...@@ -276,16 +308,30 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -276,16 +308,30 @@ class ColumnParallelLinear(torch.nn.Module):
self.bias.zero_() self.bias.zero_()
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
not args.no_async_tensor_model_parallel_allreduce and
world_size > 1)
def forward(self, input_): def forward(self, input_):
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.async_tensor_model_parallel_allreduce:
input_shape = input_.shape
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
# Matrix multiply with asynchronous all-reduce execution
output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
......
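The shape bookkeeping in the async branch of forward() above, sketched with plain tensors; the 4096 output dimension stands in for this rank's column partition of the weight, and every name here is illustrative rather than Megatron code.

import torch

s, b, h = 16, 4, 1024
x = torch.randn(s, b, h)                                   # [s, b, h] activations
x2d = x.view(x.shape[0] * x.shape[1], x.shape[2])          # flatten to [s*b, h]
out2d = x2d @ torch.randn(h, 4096)                         # per-partition matmul
out = out2d.view(x.shape[0], x.shape[1], out2d.shape[1])   # back to [s, b, partition]
print(out.shape)                                           # torch.Size([16, 4, 4096])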
...@@ -58,7 +58,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): ...@@ -58,7 +58,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
grad_not_none = param.grad is not None grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param) is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
grad = param.grad.detach() if grad_not_none:
grad = param.grad.detach()
if grad_not_none: if grad_not_none:
# Make sure the grads are in fp32 # Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor' assert param.grad.type() == 'torch.cuda.FloatTensor'
......
...@@ -179,7 +179,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): ...@@ -179,7 +179,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
a `main_grad` field. If this is set, we are assuming a `main_grad` field. If this is set, we are assuming
that the model parameters are store in the `main_grad` that the model parameters are store in the `main_grad`
field instead of the typical `grad` field. This happens field instead of the typical `grad` field. This happens
for the DDP cases where there is a contihuous buffer for the DDP cases where there is a continuous buffer
holding the gradients. For example for bfloat16, we want holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32 to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad. and as a result we store those gradients in the main_grad.
...@@ -312,7 +312,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer): ...@@ -312,7 +312,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
for model_group, main_group in zip(self.float16_groups, for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups): self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group): for model_param, main_param in zip(model_group, main_group):
if self.params_have_main_grad: if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
main_param.grad = model_param.main_grad.float() main_param.grad = model_param.main_grad.float()
else: else:
if model_param.grad is not None: if model_param.grad is not None:
......
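A hedged sketch of the guarded main_grad copy in the hunk above, with the optimizer's parameter groups replaced by plain lists; the helper name and the toy parameters are assumptions made for illustration, not the Megatron API.

import torch

def copy_model_grads_to_main_grads(model_params, main_params, params_have_main_grad):
    """Copy fp16/bf16 model gradients into the fp32 master copies."""
    for model_param, main_param in zip(model_params, main_params):
        if params_have_main_grad and hasattr(model_param, 'main_grad'):
            main_param.grad = model_param.main_grad.float()
        elif model_param.grad is not None:
            main_param.grad = model_param.grad.float()

model_p = torch.nn.Parameter(torch.randn(4, dtype=torch.float16))
model_p.main_grad = torch.ones(4)            # stands in for the DDP-fused grad buffer
main_p = torch.nn.Parameter(model_p.detach().float())
copy_model_grads_to_main_grads([model_p], [main_p], params_have_main_grad=True)
print(main_p.grad)                           # tensor([1., 1., 1., 1.])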
...@@ -22,8 +22,8 @@ from megatron import mpu ...@@ -22,8 +22,8 @@ from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False, tensor_shape=None, tensor_shape,
override_scatter_gather_tensors_in_pipeline=False, use_ring_exchange=False,
dtype_=None): dtype_=None):
"""Communicate tensors between stages. Used as helper method in other """Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py. communication methods that are used in megatron/schedules.py.
...@@ -37,16 +37,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, ...@@ -37,16 +37,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
previous rank. previous rank.
recv_next: boolean for whether tensor should be received from recv_next: boolean for whether tensor should be received from
next rank. next rank.
tensor_shape: shape of tensor to receive (this method assumes that all
tensors sent and received in a single function call are
the same shape).
use_ring_exchange: boolean for whether torch.distributed.ring_exchange() use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
API should be used. API should be used.
tensor_shape: optional, use when the input sequence contains less dtype_: optional, this is used when the tensor that needs to be
tokens than the default sequence length communicated is different from args.params_dtype.
override_scatter_gather_tensors_in_pipeline: optional, this is used
when tensor_shape is
provided to overwide
scatter gather tensors
dtype_: optional, this is used when tensor_shape is provied and what
is the type of tensor_shape
Returns: Returns:
(tensor_recv_prev, tensor_recv_next) (tensor_recv_prev, tensor_recv_next)
""" """
...@@ -56,12 +53,22 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, ...@@ -56,12 +53,22 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# if needed. # if needed.
tensor_recv_prev = None tensor_recv_prev = None
tensor_recv_next = None tensor_recv_next = None
# Some legacy inference code doesn't set the tensor shape; set it here to the
# normal values for GPT/BERT. This could be removed once the inference code is
# changed to provide tensor_shape.
if tensor_shape is None: if tensor_shape is None:
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline: override_scatter_gather_tensors_in_pipeline = False
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \ if args.scatter_gather_tensors_in_pipeline:
mpu.get_tensor_model_parallel_world_size() tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
tensor_chunk_shape = tensor_chunk_shape // \
mpu.get_tensor_model_parallel_world_size()
else:
tensor_chunk_shape = tensor_shape
override_scatter_gather_tensors_in_pipeline = True
else: else:
tensor_chunk_shape = tensor_shape tensor_chunk_shape = tensor_shape
dtype = args.params_dtype dtype = args.params_dtype
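The divisibility fallback introduced above, restated as a tiny standalone helper (the function name is an assumption; it is not part of Megatron):

from functools import reduce
import operator

def scatter_gather_chunk(tensor_shape, tp_world_size):
    """Flatten and split across tensor-parallel ranks when the element count divides
    evenly; otherwise keep the full shape and disable scatter/gather for this call."""
    numel = reduce(operator.mul, tensor_shape, 1)
    if numel % tp_world_size == 0:
        return numel // tp_world_size, False   # chunk size, keep scatter/gather
    return tensor_shape, True                  # full shape, override (disable) it

print(scatter_gather_chunk((1024, 8, 1024), 4))   # (2097152, False)
print(scatter_gather_chunk((7, 3, 5), 4))         # ((7, 3, 5), True)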
...@@ -143,9 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next, ...@@ -143,9 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
return tensor_recv_prev, tensor_recv_next return tensor_recv_prev, tensor_recv_next
def recv_forward(tensor_shape=None, def recv_forward(tensor_shape=None, dtype_=None, timers=None):
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None, timers=None):
"""Receive tensor from previous rank in pipeline (forward receive).""" """Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
...@@ -159,15 +164,13 @@ def recv_forward(tensor_shape=None, ...@@ -159,15 +164,13 @@ def recv_forward(tensor_shape=None,
recv_prev=True, recv_prev=True,
recv_next=False, recv_next=False,
tensor_shape=tensor_shape, tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=\
override_scatter_gather_tensors_in_pipeline,
dtype_=dtype_) dtype_=dtype_)
if timers is not None: if timers is not None:
timers('forward-recv').stop() timers('forward-recv').stop()
return input_tensor return input_tensor
def recv_backward(timers=None): def recv_backward(tensor_shape=None, timers=None):
"""Receive tensor from next rank in pipeline (backward receive).""" """Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
output_tensor_grad = None output_tensor_grad = None
...@@ -178,15 +181,14 @@ def recv_backward(timers=None): ...@@ -178,15 +181,14 @@ def recv_backward(timers=None):
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=None, tensor_send_prev=None,
recv_prev=False, recv_prev=False,
recv_next=True) recv_next=True,
tensor_shape=tensor_shape)
if timers is not None: if timers is not None:
timers('backward-recv').stop() timers('backward-recv').stop()
return output_tensor_grad return output_tensor_grad
def send_forward(output_tensor, timers=None, def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None):
"""Send tensor to next rank in pipeline (forward send).""" """Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage(): if not mpu.is_pipeline_last_stage():
...@@ -197,14 +199,13 @@ def send_forward(output_tensor, timers=None, ...@@ -197,14 +199,13 @@ def send_forward(output_tensor, timers=None,
tensor_send_prev=None, tensor_send_prev=None,
recv_prev=False, recv_prev=False,
recv_next=False, recv_next=False,
override_scatter_gather_tensors_in_pipeline=\ tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline,
dtype_=dtype_) dtype_=dtype_)
if timers is not None: if timers is not None:
timers('forward-send').stop() timers('forward-send').stop()
def send_backward(input_tensor_grad, timers=None): def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
"""Send tensor to previous rank in pipeline (backward send).""" """Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage(): if not mpu.is_pipeline_first_stage():
if timers is not None: if timers is not None:
...@@ -213,12 +214,13 @@ def send_backward(input_tensor_grad, timers=None): ...@@ -213,12 +214,13 @@ def send_backward(input_tensor_grad, timers=None):
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
recv_prev=False, recv_prev=False,
recv_next=False) recv_next=False,
tensor_shape=tensor_shape)
if timers is not None: if timers is not None:
timers('backward-send').stop() timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, timers=None): def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
"""Batched send and recv with next rank in pipeline.""" """Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
output_tensor_grad = None output_tensor_grad = None
...@@ -229,13 +231,14 @@ def send_forward_recv_backward(output_tensor, timers=None): ...@@ -229,13 +231,14 @@ def send_forward_recv_backward(output_tensor, timers=None):
tensor_send_next=output_tensor, tensor_send_next=output_tensor,
tensor_send_prev=None, tensor_send_prev=None,
recv_prev=False, recv_prev=False,
recv_next=True) recv_next=True,
tensor_shape=tensor_shape)
if timers is not None: if timers is not None:
timers('forward-send-backward-recv').stop() timers('forward-send-backward-recv').stop()
return output_tensor_grad return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, timers=None): def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
"""Batched send and recv with previous rank in pipeline.""" """Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
input_tensor = None input_tensor = None
...@@ -246,13 +249,14 @@ def send_backward_recv_forward(input_tensor_grad, timers=None): ...@@ -246,13 +249,14 @@ def send_backward_recv_forward(input_tensor_grad, timers=None):
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
recv_prev=True, recv_prev=True,
recv_next=False) recv_next=False,
tensor_shape=tensor_shape)
if timers is not None: if timers is not None:
timers('backward-send-forward-recv').stop() timers('backward-send-forward-recv').stop()
return input_tensor return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, timers=None): def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline.""" """Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None: if timers is not None:
timers('forward-send-forward-recv').start() timers('forward-send-forward-recv').start()
...@@ -260,13 +264,14 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None): ...@@ -260,13 +264,14 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
tensor_send_next=output_tensor, tensor_send_next=output_tensor,
tensor_send_prev=None, tensor_send_prev=None,
recv_prev=recv_prev, recv_prev=recv_prev,
recv_next=False) recv_next=False,
tensor_shape=tensor_shape)
if timers is not None: if timers is not None:
timers('forward-send-forward-recv').stop() timers('forward-send-forward-recv').stop()
return input_tensor return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None): def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline.""" """Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None: if timers is not None:
timers('backward-send-backward-recv').start() timers('backward-send-backward-recv').start()
...@@ -274,7 +279,8 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None): ...@@ -274,7 +279,8 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
recv_prev=False, recv_prev=False,
recv_next=recv_next) recv_next=recv_next,
tensor_shape=tensor_shape)
if timers is not None: if timers is not None:
timers('backward-send-backward-recv').stop() timers('backward-send-backward-recv').stop()
return output_tensor_grad return output_tensor_grad
...@@ -282,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None): ...@@ -282,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
def send_forward_backward_recv_forward_backward( def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev, output_tensor, input_tensor_grad, recv_prev,
recv_next, timers=None): recv_next, tensor_shape=None, timers=None):
"""Batched send and recv with previous and next ranks in pipeline.""" """Batched send and recv with previous and next ranks in pipeline."""
if timers is not None: if timers is not None:
timers('forward-backward-send-forward-backward-recv').start() timers('forward-backward-send-forward-backward-recv').start()
...@@ -290,7 +296,8 @@ def send_forward_backward_recv_forward_backward( ...@@ -290,7 +296,8 @@ def send_forward_backward_recv_forward_backward(
tensor_send_next=output_tensor, tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad, tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev, recv_prev=recv_prev,
recv_next=recv_next) recv_next=recv_next,
tensor_shape=tensor_shape)
if timers is not None: if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop() timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad return input_tensor, output_tensor_grad
...@@ -25,6 +25,8 @@ from megatron import p2p_communication ...@@ -25,6 +25,8 @@ from megatron import p2p_communication
from megatron.utils import unwrap_model from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module from megatron.model import Float16Module
from megatron.model import ModelType
def get_forward_backward_func(): def get_forward_backward_func():
args = get_args() args = get_args()
...@@ -48,11 +50,18 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r ...@@ -48,11 +50,18 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
passed-in input_tensor is used. passed-in input_tensor is used.
Returns output tensor.""" Returns output tensor."""
args = get_args()
timers = get_timers() timers = get_timers()
timers('forward-compute').start() timers('forward-compute').start()
unwrapped_model = unwrap_model( unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module)) model, (torchDDP, LocalDDP, Float16Module))
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
unwrapped_model.set_input_tensor(input_tensor) unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model) output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
...@@ -62,7 +71,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r ...@@ -62,7 +71,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
losses_reduced.append(loss_reduced) losses_reduced.append(loss_reduced)
timers('forward-compute').stop() timers('forward-compute').stop()
return output_tensor # If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
if mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
return output_tensor
return [output_tensor]
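Just the control flow of the new return convention above, with strings standing in for tensors (illustrative only):

def toy_forward_step_return(output_tensor, input_tensor,
                            in_decoder_of_enc_dec, unwrap_output_tensor):
    if in_decoder_of_enc_dec:
        # T5-style stage after the split: also forward the encoder hidden state,
        # which arrived as the last element of the (wrapped) input.
        return [output_tensor, input_tensor[-1]]
    if unwrap_output_tensor:
        return output_tensor
    return [output_tensor]

print(toy_forward_step_return('dec_out', ['dec_in', 'enc_hidden'], True, True))
print(toy_forward_step_return('out', ['x'], False, True))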
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad): def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
...@@ -73,24 +90,53 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad): ...@@ -73,24 +90,53 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
Returns gradient of loss with respect to input tensor (None if first Returns gradient of loss with respect to input tensor (None if first
stage).""" stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
timers('backward-compute').start() timers('backward-compute').start()
# Retain the grad on the input_tensor. # Retain the grad on the input_tensor.
if input_tensor is not None: unwrap_input_tensor_grad = False
input_tensor.retain_grad() if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
# Backward pass. # Backward pass.
if output_tensor_grad is None: if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor) output_tensor = optimizer.scale_loss(output_tensor[0])
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad) torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor. # Collect the grad of the input_tensor.
input_tensor_grad = None input_tensor_grad = [None]
if input_tensor is not None: if input_tensor is not None:
input_tensor_grad = input_tensor.grad input_tensor_grad = []
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
timers('backward-compute').stop() timers('backward-compute').stop()
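A runnable toy of the list-based backward above, without the pipeline or the skip connection: retain grads on each wrapped input, backprop through the output, and collect the per-input grads. Names are illustrative.

import torch

def toy_backward_step(input_tensor, output_tensor, output_tensor_grad=None):
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
    for x in input_tensor:
        if x is not None:
            x.retain_grad()                    # inputs are usually non-leaf here
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
    return [None if x is None else x.grad for x in input_tensor]

a = 2 * torch.randn(3, requires_grad=True)     # non-leaf, like a received activation
b = 2 * torch.randn(3, requires_grad=True)
print(toy_backward_step([a, b, None], (a * b).sum()))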
...@@ -153,6 +199,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -153,6 +199,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size() pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank() pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
args = get_args()
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
# Compute number of warmup and remaining microbatches. # Compute number of warmup and remaining microbatches.
num_model_chunks = len(model) num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks num_microbatches = get_num_microbatches() * num_model_chunks
...@@ -237,7 +286,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -237,7 +286,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
# Run warmup forward passes. # Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0) mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append( input_tensors[0].append(
p2p_communication.recv_forward(timers=timers)) p2p_communication.recv_forward(tensor_shape, timers=timers))
for k in range(num_warmup_microbatches): for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k) output_tensor = forward_step_helper(k)
...@@ -266,12 +315,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -266,12 +315,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
p2p_communication.send_forward_backward_recv_forward_backward( p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next, recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape,
timers=timers) timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad) output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else: else:
input_tensor = \ input_tensor = \
p2p_communication.send_forward_recv_forward( p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, timers=timers) output_tensor, recv_prev=recv_prev,
tensor_shape=tensor_shape,
timers=timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor) input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state. # Run 1F1B in steady state.
...@@ -335,7 +387,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -335,7 +387,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
p2p_communication.send_forward_backward_recv_forward_backward( p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next, recv_prev=recv_prev, recv_next=recv_next,
timers=timers) tensor_shape=tensor_shape, timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the # Put input_tensor and output_tensor_grad in data structures in the
# right location. # right location.
...@@ -349,7 +401,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -349,7 +401,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
if not forward_only: if not forward_only:
if all_warmup_microbatches: if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append( output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(timers=timers)) p2p_communication.recv_backward(tensor_shape, timers=timers))
for k in range(num_microbatches_remaining, num_microbatches): for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k) input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False) next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
...@@ -361,11 +413,107 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -361,11 +413,107 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_next = False recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append( output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward( p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, timers=timers)) input_tensor_grad, recv_next=recv_next,
tensor_shape=tensor_shape,
timers=timers))
return losses_reduced return losses_reduced
def get_tensor_shapes(rank, model_type):
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
# first tensor is decoder (pre-transpose),
# second tensor is encoder (post-transpose).
# If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose).
args = get_args()
tensor_shapes = []
if model_type == ModelType.encoder_and_decoder:
if mpu.is_pipeline_stage_before_split(rank):
# If next rank is after split, then need transpose for encoder_hidden_state.
if mpu.is_pipeline_stage_before_split(rank+1):
tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
else:
tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
else:
tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
else:
tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
return tensor_shapes
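A standalone rendition of the shape selection in get_tensor_shapes(), with mpu.is_pipeline_stage_before_split() and the args fields replaced by explicit parameters; every name below is an assumption made for illustration.

def toy_tensor_shapes(rank, encoder_and_decoder, split_rank,
                      seq_length, decoder_seq_length, micro_batch_size, hidden_size):
    before_split = lambda r: r < split_rank
    shapes = []
    if encoder_and_decoder:
        if before_split(rank):
            if before_split(rank + 1):
                shapes.append((seq_length, micro_batch_size, hidden_size))      # encoder, pre-transpose
            else:
                shapes.append((micro_batch_size, seq_length, hidden_size))      # boundary, post-transpose
        else:
            shapes.append((decoder_seq_length, micro_batch_size, hidden_size))  # decoder activations
            shapes.append((micro_batch_size, seq_length, hidden_size))          # encoder hidden state
    else:
        shapes.append((seq_length, micro_batch_size, hidden_size))
    return shapes

# Rank 1 of a pipeline split at stage 2 is the boundary rank: one post-transpose tensor.
print(toy_tensor_shapes(1, True, 2, 512, 128, 4, 1024))   # [(4, 512, 1024)]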
def recv_forward(tensor_shapes, timers):
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape,
timers=timers))
return input_tensors
def recv_backward(tensor_shapes, timers):
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
timers=timers))
return output_tensor_grads
def send_forward(output_tensors, tensor_shapes, timers):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers)
def send_backward(input_tensor_grads, tensor_shapes, timers):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers)
def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, tensor_shape, timers=timers)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape, timers=timers)
input_tensors.append(input_tensor)
return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
model, optimizer, timers, model, optimizer, timers,
forward_only): forward_only):
...@@ -389,6 +537,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -389,6 +537,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
num_microbatches_remaining = \ num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches num_microbatches - num_warmup_microbatches
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
model_type = unwrapped_model.model_type
rank = mpu.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
send_tensor_shapes = get_tensor_shapes(rank, model_type)
# Input, output tensors only need to be saved when doing backward passes # Input, output tensors only need to be saved when doing backward passes
input_tensors = None input_tensors = None
output_tensors = None output_tensors = None
...@@ -399,10 +554,10 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -399,10 +554,10 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# Run warmup forward passes. # Run warmup forward passes.
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers=timers) input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced) input_tensor, losses_reduced)
p2p_communication.send_forward(output_tensor, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only: if not forward_only:
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
...@@ -412,7 +567,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -412,7 +567,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# If all microbatches are run in warmup / cooldown phase, then no need to # If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here. # receive this tensor here.
if num_microbatches_remaining > 0: if num_microbatches_remaining > 0:
input_tensor = p2p_communication.recv_forward(timers=timers) input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
# Run 1F1B in steady state. # Run 1F1B in steady state.
for i in range(num_microbatches_remaining): for i in range(num_microbatches_remaining):
...@@ -421,15 +576,16 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -421,15 +576,16 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced) input_tensor, losses_reduced)
if forward_only: if forward_only:
p2p_communication.send_forward(output_tensor, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not last_iteration: if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers=timers) input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
else: else:
output_tensor_grad = \ output_tensor_grad = \
p2p_communication.send_forward_recv_backward(output_tensor, send_forward_recv_backward(output_tensor,
timers=timers) send_tensor_shapes,
timers=timers)
# Add input_tensor and output_tensor to end of list. # Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
...@@ -446,11 +602,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -446,11 +602,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if last_iteration: if last_iteration:
input_tensor = None input_tensor = None
p2p_communication.send_backward(input_tensor_grad, timers=timers) send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
else: else:
input_tensor = \ input_tensor = \
p2p_communication.send_backward_recv_forward( send_backward_recv_forward(
input_tensor_grad, timers=timers) input_tensor_grad, recv_tensor_shapes, timers=timers)
# Run cooldown backward passes. # Run cooldown backward passes.
if not forward_only: if not forward_only:
...@@ -458,12 +614,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -458,12 +614,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensor = input_tensors.pop(0) input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0) output_tensor = output_tensors.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers=timers) output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)
input_tensor_grad = \ input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor, backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad) output_tensor_grad)
p2p_communication.send_backward(input_tensor_grad, timers=timers) send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
return losses_reduced return losses_reduced
...@@ -13,18 +13,7 @@ ...@@ -13,18 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
MAJOR = 1
MINOR = 1.5
# Use the following formatting: (major, minor)
VERSION = (MAJOR, MINOR)
__version__ = '.'.join(map(str, VERSION))
__package_name__ = 'megatron-lm'
__contact_names__ = 'NVIDIA INC'
__url__ = 'https://github.com/NVIDIA/Megatron-LM'
__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
__description__ = 'Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.'
__license__ = 'See https://github.com/NVIDIA/Megatron-LM/blob/master/LICENSE'
__keywords__ = 'deep learning, Megatron, gpu, NLP, nvidia, pytorch, torch, language'
from .api import (
generate,
generate_and_post_process)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference API."""
import torch
from megatron import mpu
from .communication import broadcast_float_list
from .generation import (
generate_tokens_probs_and_return_on_first_stage,
score_and_return_on_first_stage)
from .tokenization import (
tokenize_prompts,
detokenize_generations)
def generate_and_post_process(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True):
"""Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
# Main inference.
tokens, lengths, output_log_probs = generate(
model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=return_output_log_probs,
top_k_sampling=top_k_sampling,
top_p_sampling=top_p_sampling,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
tokens, prompts_plus_generations, prompts_plus_generations_segments = \
detokenize_generations(tokens, lengths, True)
if return_output_log_probs:
output_log_probs = output_log_probs.cpu().numpy().tolist()
for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
output_log_probs[i] = prob[:len(seg)-1]
return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, tokens
return None
def generate(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
discard tokens in the tokens tensor that are after the
corresponding length.
output_log_probs: log probs of the tokens.
"""
# Make sure input params are available to all ranks.
values = [tokens_to_generate,
return_output_log_probs,
top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination]
values_float_tensor = broadcast_float_list(7, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
top_k_sampling = int(values_float_tensor[2].item())
top_p_sampling = values_float_tensor[3].item()
temperature = values_float_tensor[4].item()
add_BOS = bool(values_float_tensor[5].item())
use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
# Tokenize prompts and get the batch.
# Note that these tensors are broadcast to all ranks.
if torch.distributed.get_rank() == 0:
assert prompts is not None
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
if tokens_to_generate == 0:
return score_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor)
# Main inference function.
# Note that the outputs are available on the first stage.
return generate_tokens_probs_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
top_k=top_k_sampling,
top_p=top_p_sampling,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
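generate() ships all of its scalar knobs as one float tensor so a single broadcast covers them; the round trip looks like the sketch below (single process, the torch.distributed.broadcast step omitted; the helper name is illustrative).

import torch

def pack_unpack(tokens_to_generate, return_output_log_probs, top_k, top_p,
                temperature, add_BOS, use_eod_for_early_termination):
    values = [tokens_to_generate, return_output_log_probs, top_k, top_p,
              temperature, add_BOS, use_eod_for_early_termination]
    t = torch.tensor(values, dtype=torch.float32)   # what broadcast_float_list would ship
    return (int(t[0].item()), bool(t[1].item()), int(t[2].item()),
            t[3].item(), t[4].item(), bool(t[5].item()), bool(t[6].item()))

print(pack_unpack(64, True, 0, 0.9, 1.0, False, True))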
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Communications utilities."""
import torch
from megatron import mpu
# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
input buffer in place."""
if not mpu.is_pipeline_first_stage():
assert recv_buffer is not None
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_buffer,
mpu.get_pipeline_model_parallel_prev_rank())
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
if not mpu.is_pipeline_last_stage():
assert tensor is not None
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor,
mpu.get_pipeline_model_parallel_next_rank())
reqs = torch.distributed.batch_isend_irecv([send_next_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
def _is_cuda(tensor):
"""Check if a tensor is not none and is cuda."""
assert tensor is not None
assert tensor.is_cuda
def _is_cuda_contiguous(tensor):
"""Check if a tensor is not none, is cuda, and is contiguous."""
_is_cuda(tensor)
assert tensor.is_contiguous()
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
is_last_stage = mpu.is_pipeline_last_stage()
# If first stage and last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if mpu.is_pipeline_first_stage() and is_last_stage:
return tensor
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Get the group and corresponding source rank.
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(tensor, src, group)
return tensor
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return tensor
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor, src, group)
else:
tensor = None
return tensor
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Copy tensor values from last stage into the first stage.
Note that the input tensor is updated in place."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
_is_cuda(tensor)
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
if is_contiguous:
tensor_ = tensor
else:
if is_last_stage:
tensor_ = tensor.contiguous()
else:
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor_, src, group)
# Update the first stage tensor
if is_first_stage and not is_contiguous:
tensor[...] = tensor_
def broadcast_tensor(size, dtype, tensor=None, rank=0):
""" Given size and type of a tensor on all ranks and the tensor value
only on a specific rank, broadcast from that rank to all other ranks.
"""
if torch.distributed.get_rank() == rank:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
torch.distributed.broadcast(tensor, rank)
return tensor
def broadcast_list(size, dtype, list_values=None, rank=0):
"""Broadcast a list of values with a given type."""
tensor = None
if torch.distributed.get_rank() == rank:
tensor = torch.tensor(list_values, dtype=dtype,
device=torch.cuda.current_device())
return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
def broadcast_int_list(size, int_list=None, rank=0):
"""Broadcast a list of interger values."""
return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
def broadcast_float_list(size, float_list=None, rank=0):
"""Broadcast a list of float values."""
return broadcast_list(size, torch.float32, list_values=float_list,
rank=rank)
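A CPU-only, single-process sketch of the broadcast pattern above: every rank knows the size and dtype, only the source rank knows the values, and the other ranks allocate an empty buffer to receive into. The real helpers allocate CUDA buffers and assert contiguity; the names and the throwaway gloo group here are illustrative.

import os
import torch
import torch.distributed as dist

def toy_broadcast_list(size, dtype, list_values=None, src=0):
    if dist.get_rank() == src:
        tensor = torch.tensor(list_values, dtype=dtype)
    else:
        tensor = torch.empty(size, dtype=dtype)
    dist.broadcast(tensor, src)
    return tensor

if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29501')
    if not dist.is_initialized():
        dist.init_process_group('gloo', rank=0, world_size=1)
    print(toy_broadcast_list(3, torch.int64, [1, 2, 3]))   # tensor([1, 2, 3])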
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forward step utilities."""
from collections.abc import Iterable
import torch
from megatron import (
get_args,
mpu)
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_len):
"""Note that offsets are set to zero and we always set the
flag to allocate memory. After the first call, make sure to
set this flag to False."""
self.max_sequence_len = max_sequence_len
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
class ForwardStep:
"""Forward step function with all the communications.
We use a class here to hide the inference parameters
from the outside caller."""
def __init__(self, model, max_batch_size, max_sequence_len):
"""Set values so we don't need to do it multiple times."""
# Make sure model is in eval mode.
assert not isinstance(model, Iterable), \
'interleaving schedule is not supported for inference'
model.eval()
self.model = model
# Initialize inference parameters.
self.inference_params = InferenceParams(max_batch_size,
max_sequence_len)
# Pipelining arguments.
args = get_args()
self.pipeline_size_larger_than_one = (
args.pipeline_model_parallel_size > 1)
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
def __call__(self, tokens, position_ids, attention_mask):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return _with_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params,
micro_batch_size)
return _no_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params)
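Worked numbers for the pipelining threshold used in __call__ above; the helper below is an illustrative restatement, not part of the class.

def choose_micro_batch_size(batch_size, seq_len, batch_x_seqlen_threshold):
    """Split the batch only when batch*seq reaches the threshold."""
    if batch_size * seq_len >= batch_x_seqlen_threshold:
        return max(1, batch_x_seqlen_threshold // seq_len)
    return batch_size                           # below threshold: one micro-batch

print(choose_micro_batch_size(32, 1024, 8192))  # 8
print(choose_micro_batch_size(2, 128, 8192))    # 2 (no pipelining of the batch)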
def _get_recv_buffer_dtype(args):
"""Receive happens between the layers."""
if args.fp32_residual_connection:
return torch.float
return args.params_dtype
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
recv_size = (sequence_length, batch_size, args.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())
def _forward_step_helper(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
# Receive from previous stage.
recv_from_prev_pipeline_rank_(recv_buffer)
# Forward pass through the model.
model.set_input_tensor(recv_buffer)
output_tensor = model(tokens, position_ids, attention_mask,
inference_params=inference_params)
# Send output to the next stage.
send_to_next_pipeline_rank(output_tensor)
return output_tensor
def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
# Run a simple forward pass.
output_tensor = _forward_step_helper(model, tokens, position_ids,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Update the sequence length offset.
inference_params.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
logits = output_tensor
return logits
def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, micro_batch_size):
"""No interleaving is supported."""
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
# Divide the batch dimension into micro batches.
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
# Preallocate recv buffer.
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
# Slice along the batch dimension.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Run a simple forward pass.
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = _forward_step_helper(model, tokens2use, position_ids2use,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Adjust the batch size offset to account for the micro-batch.
inference_params.batch_size_offset += this_micro_batch_size
# Copy logits.
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output
# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
inference_params.sequence_len_offset += sequence_length
# and reset the batch size offset
inference_params.batch_size_offset = 0
return logits
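The divmod micro-batching above, reduced to its slicing arithmetic (illustrative helper, not Megatron code):

def micro_batch_slices(batch_size, micro_batch_size):
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
    return [(i * micro_batch_size, min((i + 1) * micro_batch_size, batch_size))
            for i in range(num_micro_batches)]

print(micro_batch_slices(10, 4))   # [(0, 4), (4, 8), (8, 10)] -- last chunk is smaller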
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
Arguments:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max_prompt_length]
lengths: original prompt length, size: [b]
Note: Outside of model, other parameters only need to be available on
rank 0.
Outputs:
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
batch_size = tokens.size(0)
max_prompt_length = lengths.max().item()
assert max_prompt_length == tokens.size(1)
max_sequence_length = min(max_prompt_length, args.max_position_embeddings)
# forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length)
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
if mpu.is_pipeline_last_stage():
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
# =============
# Run inference
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
# logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens, position_ids, attention_mask)
if mpu.is_pipeline_last_stage():
# The last stage should always have an output.
assert logits is not None
log_probs = F.log_softmax(logits, dim=2)
# Pick the tokens for which we need the log
# probabilities. Note that the next input token is
# the token selected from the current logits,
# so shift by 1.
indices = torch.unsqueeze(tokens[:, 1:], 2)
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, lengths, output_log_probs
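# Usage sketch (hypothetical; the token ids are arbitrary and, apart from
# `model`, the arguments only need to be valid on rank 0):
#
#   tokens = torch.cuda.LongTensor([[5962, 22307, 25, 198]])  # [b=1, s=4]
#   lengths = torch.cuda.LongTensor([4])
#   _, _, log_probs = score_and_return_on_first_stage(model, tokens, lengths)
#   # log_probs has shape [1, 3]: one log probability for each next-token
#   # position in the prompt.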
def generate_tokens_probs_and_return_on_first_stage(
model, tokens, lengths,
return_output_log_probs=False,
top_k=0, top_p=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True):
"""Main token generation function.
Arguments:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max-sequence-length]
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
from the original logit.
top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            exclusive, meaning that:
if top-k > 0 then we expect top-p=0.
if top-p > 0 then we check for top-k=0.
temperature: sampling temperature.
use_eod_token_for_early_termination: if True, do early termination if
all the sequences have reached this token.
    Note: Apart from the model, the other arguments only need to be available
          on rank 0.
    Outputs: Note that the size is adjusted to a lower value than
max-sequence-length if generation is terminated early.
tokens: prompt and generated tokens. size: [b, :]
generated_sequence_lengths: total length (including prompt) of
the generated sequence. size: [b]
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1)
max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
# forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length)
# Added termination_id to support the case that we want to terminate the
# generation once that id is generated.
if hasattr(args, 'eos_id'):
termination_id = args.eos_id
else:
termination_id = tokenizer.eod
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of generated sequences including prompts.
generated_sequence_lengths = None
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
# Whether we have reached a termination id.
is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
device=torch.cuda.current_device())
# =============
    # Run inference
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
            # logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
                # The last stage should always have an output.
assert logits is not None
# Sample.
last_token_logits = logits[:, -1, :]
new_sample = sample(last_token_logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
                # If a prompt length is smaller than or equal to the current
                # context length, it means we have started generating tokens.
started = lengths <= context_length
# Update the tokens.
tokens[started, context_length] = new_sample[started]
                # Calculate the log probabilities.
                if return_output_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that the next input token is
                    # the token we selected from the current logits,
                    # so shift by 1.
                    indices = torch.unsqueeze(
                        tokens[:, (prev_context_length + 1):(context_length + 1)],
                        2)
                    output_log_probs[:, prev_context_length:context_length] = \
                        torch.gather(log_probs, 2, indices).squeeze(2)
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
# Update the context length for the next token generation.
prev_context_length = context_length
# Check if all the sequences have hit the termination_id.
done = None
if mpu.is_pipeline_last_stage():
done_token = (new_sample == termination_id).byte() & \
started.byte()
just_finished = (done_token & ~is_generation_done).bool()
generated_sequence_lengths[just_finished.view(-1)] = \
context_length + 1
is_generation_done = is_generation_done | done_token
done = torch.all(is_generation_done)
done = broadcast_from_last_pipeline_stage(1, torch.uint8,
tensor=done)
if use_eod_token_for_early_termination and done:
break
# ===================================================
    # Update the length based on the max generated length.
# ===================================================
tokens = tokens[:, :(context_length + 1)]
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = output_log_probs[:, :context_length]
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
batch_size, torch.int64, generated_sequence_lengths)
if return_output_log_probs:
output_log_probs_size = (batch_size, context_length)
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, generated_sequence_lengths, output_log_probs
def _build_attention_mask_and_position_ids(tokens):
"""Build the attention mask and postition ids for the input tokens."""
# Since we are not interested in loss-mask and reset attention/position
# is also False, eod_token is not used so it is safe to set it to None.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=None,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False)
return attention_mask, position_ids
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sampling utilities.
Part of this code is inspired by:
- https://github.com/ari-holtzman/degen/blob/master/gen.py
- https://huggingface.co/transformers/_modules/transformers/generation_logits_process.html
"""
import torch
def modify_logits_for_top_k_filtering(logits, top_k):
"""Set the logits for none top-k values to -inf."""
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits.masked_fill_(filter_, float('-Inf'))
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Filter based on the cumulative sum.
filter_ = cumulative_probs > top_p
    # Shift the mask to the right by 1 so that the token that first pushes the
    # cumulative probability above top_p is also kept. This follows the
    # original implementation:
    # https://github.com/ari-holtzman/degen/blob/master/gen.py
    filter_[:, 1:] = filter_[:, :-1].clone()
# Make sure we at least have one token to select from.
filter_[..., 0] = 0
    # Scatter the filter back to the original (unsorted) index order.
filter_ = filter_.scatter(1, sorted_indices, filter_)
logits.masked_fill_(filter_, float('-Inf'))
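# Worked example (illustrative numbers, not from the original code): for a
# single row of logits [0., 1., 2., 3.] the softmax probabilities are roughly
# [0.03, 0.09, 0.24, 0.64]. With top_p=0.9 the sorted cumulative sums are
# [0.64, 0.88, 0.97, 1.00]; because of the shift above, the token that first
# crosses 0.9 is kept, so only the lowest-probability token is masked:
#
#   logits = torch.tensor([[0., 1., 2., 3.]])
#   modify_logits_for_top_p_filtering(logits, top_p=0.9)
#   # logits is now [[-inf, 1., 2., 3.]]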
def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
""" Sample and generate a token.
Note: logits has the dimension [b, v] where b is the batch size
and v is the vocabulary size.
If vocab_size is provided, we will make sure the sample that is
generated is in [0, vocab-size). This will avoid out of vocabulary
generations due to padding.
"""
# Check logits for consistency.
assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
assert logits.type() == 'torch.cuda.FloatTensor', \
'input logits should be floats.'
# Greedy is just simple argmax.
if top_k == 1:
assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
samples = torch.argmax(logits, dim=-1)
# Top-k or top-p sampling.
else:
        # Clone so we do not modify the inputs.
logits = logits.clone()
# Apply temperature in place.
if temperature != 1.0:
logits.div_(temperature)
if top_k > 1:
assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
assert top_k <= logits.size(1), 'top-k is larger than logit size.'
if vocab_size:
assert top_k < vocab_size, 'top-k is larger than vocab size.'
modify_logits_for_top_k_filtering(logits, top_k)
elif top_p > 0.0:
assert top_p <= 1.0, 'top-p should be in (0, 1].'
modify_logits_for_top_p_filtering(logits, top_p)
# After filtering, we need to recalculate the distribution.
probs = logits.softmax(dim=-1)
samples = torch.multinomial(probs, num_samples=1).view(-1)
    # If vocab size is provided, make sure the samples are
    # in the range [0, vocab-size).
if vocab_size:
samples = torch.clamp(samples, min=0, max=(vocab_size - 1))
return samples
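# Usage sketch (hypothetical shapes and values; `sample` expects float logits
# of shape [b, v] on the current CUDA device):
#
#   logits = torch.randn(4, 50304, device='cuda')   # b=4, padded vocab
#   new_tokens = sample(logits, top_k=0, top_p=0.9,
#                       temperature=0.8, vocab_size=50257)
#   # new_tokens has shape [4]; passing vocab_size clamps samples drawn from
#   # padded vocabulary positions back into [0, 50257).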
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization utilities."""
import torch
from megatron import get_tokenizer
from .communication import broadcast_int_list, broadcast_tensor
def detokenize_generations(tokens_gpu_tensor,
lengths_gpu_tensor,
return_segments):
"""Detokenize the generated tokens."""
tokenizer = get_tokenizer()
prompts_plus_generations = []
if return_segments:
prompts_plus_generations_segments = []
tokens = tokens_gpu_tensor.cpu().numpy().tolist()
lengths = lengths_gpu_tensor.cpu().numpy().tolist()
for sequence_tokens, length in zip(tokens, lengths):
sequence_tokens = sequence_tokens[:length]
prompts_plus_generations.append(
tokenizer.detokenize(sequence_tokens))
if return_segments:
words = []
for token in sequence_tokens:
word = tokenizer.tokenizer.decoder[token]
word = bytearray(
[tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
'utf-8', errors='replace')
words.append(word)
prompts_plus_generations_segments.append(words)
if return_segments:
return tokens, prompts_plus_generations, \
prompts_plus_generations_segments
return tokens, prompts_plus_generations
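# Usage sketch (hypothetical; assumes a GPT-2 style BPE tokenizer whose inner
# `tokenizer.tokenizer` exposes `decoder` and `byte_decoder`, as used above):
#
#   tokens, texts, segments = detokenize_generations(
#       tokens_gpu_tensor, lengths_gpu_tensor, return_segments=True)
#   # texts[i] is the detokenized prompt+generation for sample i and
#   # segments[i] is the corresponding list of decoded word pieces.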
def tokenize_prompts(prompts=None, tokens_to_generate=None,
add_BOS=None, rank=0):
"""Tokenize prompts and make them avaiable on all ranks."""
# On all ranks set to None so we can pass them to functions
sizes_list = None
prompts_tokens_cuda_long_tensor = None
prompts_length_cuda_long_tensor = None
# On the specified rank, build the above.
if torch.distributed.get_rank() == rank:
assert prompts is not None
assert tokens_to_generate is not None
# Tensor of tokens padded and their unpadded length.
prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor = \
_tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS)
        # We need the sizes of these tensors for the broadcast.
        sizes_list = [prompts_tokens_cuda_long_tensor.size(0), # Batch size
                      prompts_tokens_cuda_long_tensor.size(1)] # Sequence length
# First, broadcast the sizes.
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=rank)
    # Now that we have the sizes, we can broadcast the tokens
# and length tensors.
sizes = sizes_tensor.tolist()
prompts_tokens_cuda_long_tensor = broadcast_tensor(
sizes, torch.int64, tensor=prompts_tokens_cuda_long_tensor, rank=rank)
prompts_length_cuda_long_tensor = broadcast_tensor(
sizes[0], torch.int64, tensor=prompts_length_cuda_long_tensor,
rank=rank)
return prompts_tokens_cuda_long_tensor, prompts_length_cuda_long_tensor
def _tokenize_prompts_and_batch(prompts, tokens_to_generate, add_BOS):
"""Given a set of prompts and number of tokens to generate:
- tokenize prompts
- set the sequence length to be the max of length of prompts
plus the number of tokens we would like to generate
- pad all the sequences to this length so we can convert them
into a 2D tensor.
"""
# Tokenize all the prompts.
tokenizer = get_tokenizer()
if add_BOS:
prompts_tokens = [[tokenizer.eod] + tokenizer.tokenize(prompt)
for prompt in prompts]
else:
prompts_tokens = [tokenizer.tokenize(prompt) for prompt in prompts]
# Now we have a list of list of tokens which each list has a different
# size. We want to extend this list to:
# - incorporate the tokens that need to be generated
# - make all the sequences equal length.
# Get the prompts length.
prompts_length = [len(prompt_tokens) for prompt_tokens in prompts_tokens]
# Get the max prompts length.
max_prompt_len = max(prompts_length)
    # Number of tokens in each sample of the batch.
samples_length = max_prompt_len + tokens_to_generate
# Now update the list of list to be of the same size: samples_length.
for prompt_tokens, prompt_length in zip(prompts_tokens, prompts_length):
padding_size = samples_length - prompt_length
prompt_tokens.extend([tokenizer.eod] * padding_size)
    # Now that we are in a structured format, we can convert to tensors.
prompts_tokens_tensor = torch.cuda.LongTensor(prompts_tokens)
prompts_length_tensor = torch.cuda.LongTensor(prompts_length)
return prompts_tokens_tensor, prompts_length_tensor
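# Worked example (illustrative, assuming a GPT2BPETokenizer whose eod id is
# 50256): with prompts that tokenize to lengths 3 and 5 and
# tokens_to_generate=4, samples_length = 5 + 4 = 9, so the first prompt is
# padded with six eod tokens and the second with four. The returned tensors
# are then:
#
#   prompts_tokens_tensor.shape == (2, 9)
#   prompts_length_tensor.tolist() == [3, 5]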
...@@ -12,55 +12,119 @@ ...@@ -12,55 +12,119 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import datetime
import torch import torch
import json
import threading
from flask import Flask, request, jsonify, current_app from flask import Flask, request, jsonify, current_app
from flask_restful import Resource, Api from flask_restful import Resource, Api
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron.text_generation import generate_and_post_process
from megatron.text_generation_utils import generate
GENERATE_NUM = 0 GENERATE_NUM = 0
lock = threading.Lock()
class MegatronGenerate(Resource): class MegatronGenerate(Resource):
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
@staticmethod @staticmethod
def send_do_generate(): def send_do_generate():
choice = torch.cuda.LongTensor([GENERATE_NUM]) choice = torch.cuda.LongTensor([GENERATE_NUM])
torch.distributed.broadcast(choice, torch.distributed.broadcast(choice, 0)
mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())
def put(self): def put(self):
args = get_args() args = get_args()
sentences = request.get_json()["sentences"] print("request IP: " + str(request.remote_addr))
if len(sentences) > 128: print(json.dumps(request.get_json()),flush=True)
return "Maximum number of sentences is 128", 400 print("current time: ", datetime.datetime.now())
max_len = 64 # Choosing hopefully sane default. Full sequence is slow if not "prompts" in request.get_json():
return "prompts argument required", 400
if "max_len" in request.get_json(): if "max_len" in request.get_json():
max_len = request.get_json()["max_len"] return "max_len is no longer used. Replace with tokens_to_generate", 400
if not isinstance(max_len, int):
return "max_len must be an integer greater than 0" if "sentences" in request.get_json():
if max_len < 1: return "sentences is no longer used. Replace with prompts", 400
return "max_len must be an integer greater than 0"
prompts = request.get_json()["prompts"]
if len(prompts) > 128:
return "Maximum number of prompts is 128", 400
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate tokens_to_generate = 64 # Choosing hopefully sane default. Full sequence is slow
resp_sentences = generate(self.model, sentences, max_len) if "tokens_to_generate" in request.get_json():
return jsonify({"sentences": resp_sentences}) tokens_to_generate = request.get_json()["tokens_to_generate"]
if not isinstance(tokens_to_generate, int):
return "tokens_to_generate must be an integer greater than 0"
if tokens_to_generate < 0:
return "tokens_to_generate must be an integer greater than or equal to 0"
logprobs = False
if "logprobs" in request.get_json():
logprobs = request.get_json()["logprobs"]
if not isinstance(logprobs, bool):
return "logprobs must be a boolean value"
if tokens_to_generate == 0 and not logprobs:
return "tokens_to_generate=0 implies logprobs should be True"
temperature = 1.0
if "temperature" in request.get_json():
temperature = request.get_json()["temperature"]
if not (type(temperature) == int or type(temperature) == float):
return "temperature must be a positive number less than or equal to 100.0"
if not (0.0 < temperature <= 100.0):
return "temperature must be a positive number less than or equal to 100.0"
top_k = 0.0
if "top_k" in request.get_json():
top_k = request.get_json()["top_k"]
if not (type(top_k) == int):
return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
if not (0 <= top_k <= 1000):
return "top_k must be equal to or greater than 0 and less than or equal to 1000"
top_p = 0.0
if "top_p" in request.get_json():
top_p = request.get_json()["top_p"]
if not (type(top_p) == float):
return "top_p must be a positive float less than or equal to 1.0"
if top_p > 0.0 and top_k > 0.0:
return "cannot set both top-k and top-p samplings."
if not (0 <= top_p <= 1.0):
return "top_p must be less than or equal to 1.0"
add_BOS = False
if "add_BOS" in request.get_json():
add_BOS = request.get_json()["add_BOS"]
if not isinstance(add_BOS, bool):
return "add_BOS must be a boolean value"
def index(): with lock: # Need to get lock to keep multiple threads from hitting code
return current_app.send_static_file('index.html') MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=logprobs,
top_k_sampling=top_k,
top_p_sampling=top_p,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=True)
return jsonify({"text": response,
"segments": response_seg,
"logprobs": response_logprobs})
class MegatronServer(object): class MegatronServer(object):
def __init__(self, model): def __init__(self, model):
self.app = Flask(__name__) self.app = Flask(__name__, static_url_path='')
self.app.add_url_rule('/', 'index', index)
api = Api(self.app) api = Api(self.app)
api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model]) api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
def run(self, url): def run(self, url):
self.app.run(url, threaded=False, debug=False) self.app.run(url, threaded=True, debug=False)
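# Client-side usage sketch for the updated server above (hypothetical host and
# port; the payload keys mirror the JSON fields validated in
# MegatronGenerate.put, and the response carries "text", "segments" and
# "logprobs"):
#
#   import requests
#   resp = requests.put("http://localhost:5000/api",
#                       json={"prompts": ["Hello, my name is"],
#                             "tokens_to_generate": 32,
#                             "logprobs": True,
#                             "temperature": 0.9,
#                             "top_p": 0.9})
#   print(resp.json()["text"])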
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating text."""
import copy
import json
import os
import time
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
# These are needed to unwrap the model; it would be nice to put them in megatron.utils if possible.
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_batch(context_tokens):
"""Generate batch from context tokens."""
args = get_args()
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.contiguous().cuda()
    # Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
return tokens, attention_mask, position_ids
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313 """
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
        # Sort the logits in descending order to compute cumulative probabilities.
sorted_logits, sorted_indices = torch.sort(
logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1),
dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to also keep the first token
        # above the threshold
sorted_indices_to_remove[..., 1:] \
= sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
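# Worked example (illustrative numbers): for a single row of logits
# [0., 1., 2., 3.] and top_k=2, the smallest retained logit is 2., so every
# logit below it is replaced with -inf:
#
#   logits = top_k_logits(torch.tensor([[0., 1., 2., 3.]]), top_k=2)
#   # logits is now [[-inf, -inf, 2., 3.]]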
def pad_batch(batch, pad_id, args):
context_lengths = []
for tokens in batch:
context_length = len(tokens)
if context_length < args.seq_length:
tokens.extend([pad_id] * (args.seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths
def tokenize_batch(sentences):
args = get_args()
tokenizer = get_tokenizer()
context_tokens = [tokenizer.tokenize(s) for s in sentences]
context_tokens, context_lengths = pad_batch(context_tokens,
tokenizer.eod, args)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
return context_tokens_tensor, context_length_tensor
def send_generate_info(context_tokens_tensor, context_length_tensor, max_len):
"""
Needs to be synced up with receive_generate_info
"""
# Send the sizes of the tensors
input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), max_len]
input_info_tensor = torch.cuda.LongTensor(input_info)
torch.distributed.broadcast(input_info_tensor, 0)
# Send variables to all ranks
torch.distributed.broadcast(context_length_tensor, 0)
torch.distributed.broadcast(context_tokens_tensor, 0)
def receive_generate_info():
"""
Needs to be synced up with send_generate_info
"""
input_info_tensor = torch.empty(3, dtype=torch.int64, device=torch.cuda.current_device())
torch.distributed.broadcast(input_info_tensor, 0)
batch_size = input_info_tensor[0].item()
seq_len = input_info_tensor[1].item()
max_len = input_info_tensor[2].item()
context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
# Send variables to all ranks
torch.distributed.broadcast(context_length_tensor, 0)
torch.distributed.broadcast(context_tokens_tensor, 0)
return context_length_tensor, context_tokens_tensor, max_len
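# Pairing sketch: rank 0 builds the tensors and calls send_generate_info,
# while every other rank calls receive_generate_info, exactly as generate()
# does below:
#
#   if torch.distributed.get_rank() == 0:
#       send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
#   else:
#       context_length_tensor, context_tokens_tensor, max_len = \
#           receive_generate_info()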
def synced_generate(model, context_tokens_tensor, context_length_tensor, max_len):
context_length = context_length_tensor.min().item()
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
context_length_tensor,
attention_mask, position_ids,
max_len)
for tokens, lengths in batch_token_iterator:
context_length += 1
if tokens is not None:
return tokens[:, :context_length]
def generate(model, sentences=None, max_len=0):
model.eval()
if torch.distributed.get_rank() == 0:
context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
else:
context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
decode_tokens = synced_generate(model, context_tokens_tensor, context_length_tensor, max_len)
if torch.distributed.get_rank() == 0:
args = get_args()
tokenizer = get_tokenizer()
resp_sentences = []
for i in range(decode_tokens.size(0)):
decode_token = decode_tokens[i,:].cpu().numpy().tolist()
resp_sentences.append(tokenizer.detokenize(decode_token))
return resp_sentences
def generate_samples_eval(model, context, max_gen_length, eos_token_id):
"""
    This function is here to provide a matching API for a legacy task.
    This implementation hasn't been tested yet to make sure it matches.
"""
assert False, "Implementation untested"
args = get_args()
args.eos_id = eos_token_id
raw_text_len = len(context)
resp_sentences = generate(model, [context], max_gen_length)
return resp_sentences[0][raw_text_len:]
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
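# Example (illustrative): switch keeps val1 where boolean is 0 and takes val2
# where boolean is 1, e.g.
#
#   switch(torch.tensor([10., 20.]), torch.tensor([30., 40.]),
#          torch.tensor([0, 1]))    # -> tensor([10., 40.])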
def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
layer_past=None, get_key_value=None,
forward_method_parallel_output=None):
    # The shape of the hidden state changes when not using recompute, so we
    # need to tell the p2p communication functions the correct sequence
    # length and micro batch size.
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
args.micro_batch_size = tokens.shape[0]
input_tensor = recv_forward()
# Forward pass through the model.
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output_tensor = model(tokens, position_ids, attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
forward_method_parallel_output=forward_method_parallel_output)
if get_key_value:
output_tensor, layer_past = output_tensor
send_forward(output_tensor)
args.seq_length = orig_seq_length
if get_key_value:
return output_tensor, layer_past
return output_tensor
def sample_sequence_batch(model, context_tokens, context_lengths,
attention_mask, position_ids,
maxlen=None, type_ids=None):
args = get_args()
tokenizer = get_tokenizer()
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
        # Added eos_id to support the function generate_samples_eval that
        # passes eos_id as an argument and needs termination when that id
        # is found.
if hasattr(args, 'eos_id'):
eos_id = args.eos_id
else:
eos_id = tokenizer.eod
counter = 0
org_context_length = context_length
layer_past = None
batch_size = context_tokens.size(0)
is_done = torch.zeros([batch_size]).byte().cuda()
tokens = context_tokens
if maxlen is None:
maxlen = args.seq_length - 1
maxlen = maxlen + org_context_length
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen
while context_length < maxlen:
types2use = None
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
if type_ids is not None:
types2use = type_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(
batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(
batch_size, -1)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(
batch_size, -1)
output, layer_past = forward_step(model, tokens2use,
positions2use,
attention_mask,
layer_past=layer_past,
get_key_value=True,
tokentype_ids=types2use,
forward_method_parallel_output=False)
if mpu.is_pipeline_last_stage():
assert output is not None
logits = output[:, -1].view(batch_size, -1).contiguous()
if mpu.is_pipeline_last_stage():
if args.greedy:
prev = torch.argmax(logits, dim=-1).view(-1)
else:
logits = logits.float()
logits /= args.temperature
logits = top_k_logits(logits, top_k=args.top_k,
top_p=args.top_p)
                        probs = F.softmax(logits, dim=-1)
                        prev = torch.multinomial(probs, num_samples=1).view(-1)
started = context_lengths <= context_length
new_tokens = switch(
tokens[:, context_length].view(-1), prev, started)
tokens[:, context_length] = new_tokens
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
torch.distributed.broadcast(new_tokens, src, group)
done_token = (prev == eos_id).byte() & started.byte()
just_finished = (done_token & ~is_done).bool()
lengths[just_finished.view(-1)] = context_length
is_done = is_done | done_token
done = torch.all(is_done)
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
yield tokens, lengths
else:
if mpu.is_pipeline_first_stage():
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
new_tokens = torch.empty_like(tokens[:, context_length])
torch.distributed.broadcast(new_tokens, src, group)
tokens[:, context_length] = new_tokens
yield tokens, None
else:
yield None, None
done = torch.cuda.ByteTensor([0])
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
context_length += 1
counter += 1
if done:
break
...@@ -38,6 +38,7 @@ from megatron import print_rank_last ...@@ -38,6 +38,7 @@ from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard from megatron.initialize import write_args_to_tensorboard
...@@ -61,6 +62,7 @@ def print_datetime(string): ...@@ -61,6 +62,7 @@ def print_datetime(string):
def pretrain(train_valid_test_dataset_provider, def pretrain(train_valid_test_dataset_provider,
model_provider, model_provider,
model_type,
forward_step_func, forward_step_func,
extra_args_provider=None, extra_args_provider=None,
args_defaults={}): args_defaults={}):
...@@ -77,6 +79,7 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -77,6 +79,7 @@ def pretrain(train_valid_test_dataset_provider,
train/valid/test dataset and returns `train, valid, test` datasets. train/valid/test dataset and returns `train, valid, test` datasets.
model_provider: a function that returns a vanilla version of the model_provider: a function that returns a vanilla version of the
model. By vanilla we mean a simple model on cpu with no fp16 or ddp. model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
model_type: an enum that specifies the type of model being trained.
forward_step_func: a function that takes a `data iterator` and `model`, forward_step_func: a function that takes a `data iterator` and `model`,
and returns a `loss` scalar with a dictionary with key:values being and returns a `loss` scalar with a dictionary with key:values being
the info we would like to monitor during training, for example the info we would like to monitor during training, for example
...@@ -109,7 +112,8 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -109,7 +112,8 @@ def pretrain(train_valid_test_dataset_provider,
# Model, optimizer, and learning rate. # Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start() timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
model_type)
timers('model-and-optimizer-setup').stop() timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate ' print_datetime('after model, optimizer, and learning rate '
'scheduler are built') 'scheduler are built')
...@@ -189,13 +193,16 @@ def update_train_iters(args): ...@@ -189,13 +193,16 @@ def update_train_iters(args):
print_rank_0('setting training iterations to {}'.format(args.train_iters)) print_rank_0('setting training iterations to {}'.format(args.train_iters))
def get_model(model_provider_func): def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model.""" """Build the model."""
args = get_args() args = get_args()
args.model_type = model_type
# Build model. # Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \ if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None: args.virtual_pipeline_model_parallel_size is not None:
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = [] model = []
for i in range(args.virtual_pipeline_model_parallel_size): for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i) mpu.set_virtual_pipeline_model_parallel_rank(i)
...@@ -206,14 +213,36 @@ def get_model(model_provider_func): ...@@ -206,14 +213,36 @@ def get_model(model_provider_func):
pre_process=pre_process, pre_process=pre_process,
post_process=post_process post_process=post_process
) )
this_model.model_type = model_type
model.append(this_model) model.append(this_model)
else: else:
pre_process = mpu.is_pipeline_first_stage() pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage() post_process = mpu.is_pipeline_last_stage()
model = model_provider_func( add_encoder = True
pre_process=pre_process, add_decoder = True
post_process=post_process if model_type == ModelType.encoder_and_decoder:
) if mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.pipeline_model_parallel_split_rank is not None, \
"Split rank needs to be specified for model with both encoder and decoder"
rank = mpu.get_pipeline_model_parallel_rank()
split_rank = args.pipeline_model_parallel_split_rank
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == split_rank
post_process = (rank == (split_rank - 1)) or (
rank == (world_size - 1))
add_encoder = mpu.is_pipeline_stage_before_split()
add_decoder = mpu.is_pipeline_stage_after_split()
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.model_type = model_type
if not isinstance(model, list): if not isinstance(model, list):
model = [model] model = [model]
...@@ -243,22 +272,24 @@ def get_model(model_provider_func): ...@@ -243,22 +272,24 @@ def get_model(model_provider_func):
if args.fp16 or args.bf16: if args.fp16 or args.bf16:
model = [Float16Module(model_module, args) for model_module in model] model = [Float16Module(model_module, args) for model_module in model]
if args.DDP_impl == 'torch': if wrap_with_ddp:
i = torch.cuda.current_device() if args.DDP_impl == 'torch':
model = [torchDDP(model_module, device_ids=[i], output_device=i, i = torch.cuda.current_device()
process_group=mpu.get_data_parallel_group()) model = [torchDDP(model_module, device_ids=[i], output_device=i,
for model_module in model] process_group=mpu.get_data_parallel_group())
return model for model_module in model]
if args.DDP_impl == 'local': elif args.DDP_impl == 'local':
model = [LocalDDP(model_module, model = [LocalDDP(model_module,
args.accumulate_allreduce_grads_in_fp32, args.accumulate_allreduce_grads_in_fp32,
args.use_contiguous_buffers_in_local_ddp) args.use_contiguous_buffers_in_local_ddp)
for model_module in model] for model_module in model]
return model
else:
raise NotImplementedError('Unknown DDP implementation specified: '
'{}. Exiting.'.format(args.DDP_impl))
raise NotImplementedError('Unknown DDP implementation specified: {}. ' return model
'Exiting.'.format(args.DDP_impl))
def get_learning_rate_scheduler(optimizer): def get_learning_rate_scheduler(optimizer):
...@@ -304,11 +335,11 @@ def get_learning_rate_scheduler(optimizer): ...@@ -304,11 +335,11 @@ def get_learning_rate_scheduler(optimizer):
return lr_scheduler return lr_scheduler
def setup_model_and_optimizer(model_provider_func): def setup_model_and_optimizer(model_provider_func, model_type):
"""Setup model and optimizer.""" """Setup model and optimizer."""
args = get_args() args = get_args()
model = get_model(model_provider_func) model = get_model(model_provider_func, model_type)
unwrapped_model = unwrap_model(model, unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module)) (torchDDP, LocalDDP, Float16Module))
...@@ -377,13 +408,14 @@ def train_step(forward_step_func, data_iterator, ...@@ -377,13 +408,14 @@ def train_step(forward_step_func, data_iterator,
# This should only run for models that support pipelined model parallelism # This should only run for models that support pipelined model parallelism
# (BERT and GPT-2). # (BERT and GPT-2).
timers('backward-embedding-all-reduce').start() timers('backward-embedding-all-reduce').start()
if (mpu.is_pipeline_first_stage(ignore_virtual=True) or if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
mpu.get_pipeline_model_parallel_world_size() > 1: mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True): if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0] unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True): elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1] unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model( unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module)) unwrapped_model, (torchDDP, LocalDDP, Float16Module))
......
...@@ -25,7 +25,7 @@ from megatron import print_rank_0 ...@@ -25,7 +25,7 @@ from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel from megatron.model import BertModel, ModelType
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
...@@ -143,5 +143,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -143,5 +143,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__": if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step, pretrain(train_valid_test_datasets_provider, model_provider,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
...@@ -23,7 +23,7 @@ from megatron import get_timers ...@@ -23,7 +23,7 @@ from megatron import get_timers
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron.data.gpt_dataset import build_train_valid_test_datasets from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel from megatron.model import GPTModel, ModelType
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
...@@ -121,5 +121,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -121,5 +121,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__": if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step, pretrain(train_valid_test_datasets_provider, model_provider,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
...@@ -28,6 +28,7 @@ from megatron import get_timers ...@@ -28,6 +28,7 @@ from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ModelType
from megatron.model.biencoder_model import biencoder_model_provider from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
...@@ -174,5 +175,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -174,5 +175,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__": if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, pretrain(train_valid_test_datasets_provider,
pretrain_ict_model_provider, pretrain_ict_model_provider,
ModelType.encoder_or_decoder,
forward_step, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
...@@ -26,18 +26,58 @@ from megatron import ( ...@@ -26,18 +26,58 @@ from megatron import (
print_rank_0 print_rank_0
) )
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model from megatron.model import T5Model, ModelType
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True): """
Pipeline parallelism for T5
===========================
T5 is a model architecture with both encoder and decoder blocks.
Consequently, pipeline parallelism is implemented slightly differently
compared to architectures like GPT and BERT.
In particular, when pipeline_model_parallel_world_size > 1, each stage
either executes an encoder block or a decoder block. The
--pipeline-model-parallel-split-rank argument controls the rank at which
the split happens: all ranks lower than this argument execute the
encoder block, and all ranks equal to or higher than this argument value
execute the decoder block.
In the encoder section of the model, only one tensor is sent downstream:
the intermediate encoder_hidden_state. In the decoder section of the
model, two tensors are sent downstream in the forward pass: the fully
computed encoder_hidden_state, and the intermediate decoder_hidden_state.
In particular, these are the shapes of the tensors sent between
different workers:
If rank is in decoder section:
intermediate decoder_hidden_state (pre-transpose),
complete encoder_hidden_state (post-transpose).
If rank is at boundary between encoder and decoder sections:
complete encoder_hidden_state (post-transpose).
If rank is in encoder section:
intermediate encoder_hidden_state (pre-transpose).
Additionally, we have code in the backward_step function in schedules.py
to accumulate the encoder_hidden_state gradient across skip connections
(encoder_hidden_state fed in as input to each layer in the decoder).
"""
def model_provider(pre_process=True, post_process=True,
add_encoder=True, add_decoder=True):
"""Build the model.""" """Build the model."""
assert pre_process and post_process, "T5 doesn't yet support pipelining"
print_rank_0('building T5 model ...') print_rank_0('building T5 model ...')
model = T5Model(num_tokentypes=0, model = T5Model(num_tokentypes=0,
parallel_output=True) parallel_output=True,
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
return model return model
...@@ -70,9 +110,7 @@ def get_batch(data_iterator): ...@@ -70,9 +110,7 @@ def get_batch(data_iterator):
def loss_func(loss_mask, output_tensor): def loss_func(loss_mask, output_tensor):
lm_loss_, _ = output_tensor lm_loss_ = output_tensor.float()
lm_loss_ = lm_loss_.float()
lm_loss = torch.sum( lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
...@@ -130,5 +168,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -130,5 +168,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__": if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step, pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_and_decoder,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) forward_step, args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
...@@ -20,6 +20,7 @@ import torch.nn.functional as F ...@@ -20,6 +20,7 @@ import torch.nn.functional as F
from functools import partial from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0 from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vit_model import VitModel from megatron.model.vit_model import VitModel
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
...@@ -92,6 +93,7 @@ if __name__ == "__main__": ...@@ -92,6 +93,7 @@ if __name__ == "__main__":
pretrain( pretrain(
train_valid_test_datasets_provider, train_valid_test_datasets_provider,
model_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, forward_step,
args_defaults={'dataloader_type': 'cyclic'} args_defaults={'dataloader_type': 'cyclic'}
) )