Commit 6fd0b406 authored by zihanl

merge with main branch

parents 492fdf83 60750922
@@ -58,6 +58,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none:
grad = param.grad.detach()
if grad_not_none:
# Make sure the grads are in fp32
...
@@ -68,7 +68,9 @@ class MegatronOptimizer(ABC):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad,
use_contiguous_buffers_in_local_ddp):
"""Input optimizer is the base optimizer for example Adam."""
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
@@ -76,7 +78,11 @@ class MegatronOptimizer(ABC):
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
self.params_have_main_grad = params_have_main_grad
self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
if self.use_contiguous_buffers_in_local_ddp:
assert self.params_have_main_grad, \
"use of contiguous buffer requires that params have main grad"
def get_parameters(self):
params = []
@@ -173,7 +179,7 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
a `main_grad` field. If this is set, we are assuming
that the model parameters are stored in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a contiguous buffer
holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
@@ -187,11 +193,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
"""
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler):
super(Float16OptimizerWithFloat16Params, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp)
self.bf16 = bf16
self.grad_scaler = grad_scaler
@@ -282,9 +289,14 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
float16_groups & fp32_from_fp32_groups. We additionally zero
fp32_from_float16_groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point."""
for group in self.float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_float16_groups:
_zero_grad_group_helper(group, set_to_none)
for group in self.fp32_from_fp32_groups:
_zero_grad_group_helper(group, set_to_none)
@@ -300,17 +312,31 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
for model_group, main_group in zip(self.float16_groups,
self.fp32_from_float16_groups):
for model_param, main_param in zip(model_group, main_group):
if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
main_param.grad = model_param.main_grad.float()
else:
if model_param.grad is not None:
main_param.grad = model_param.grad.float()
# Safe to deallocate model's grad/main_grad after copying.
# (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.)
model_param.grad = None
if self.params_have_main_grad and \
not self.use_contiguous_buffers_in_local_ddp:
model_param.main_grad = None
# For fp32 grads, we need to reset the grads to main grad.
if self.params_have_main_grad:
for model_group in self.fp32_from_fp32_groups:
for model_param in model_group:
model_param.grad = model_param.main_grad
# Safe to de-reference model's main_grad after copying.
# (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.)
if not self.use_contiguous_buffers_in_local_ddp:
model_param.main_grad = None
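The copy-and-free pattern above can be summarized in a short sketch (illustrative only, not part of the commit; the helper name is hypothetical): fp32 main_grad buffers are copied into the fp32 main parameters, and the model-side gradient references are dropped only when they are not views into a shared contiguous buffer.

# Illustrative sketch of the pattern above; the helper name is hypothetical.
def _copy_and_free_model_grads(model_param, main_param,
                               use_contiguous_buffers_in_local_ddp):
    if hasattr(model_param, 'main_grad'):
        # main_grad already holds the accumulated (fp32) gradient.
        main_param.grad = model_param.main_grad.float()
    elif model_param.grad is not None:
        main_param.grad = model_param.grad.float()
    # The model-side grad is no longer needed after the copy.
    model_param.grad = None
    # With contiguous buffers, main_grad is a view into one shared buffer
    # whose memory must persist, so it is only dropped in the other case.
    if not use_contiguous_buffers_in_local_ddp:
        model_param.main_grad = None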
def _unscale_main_grads_and_check_for_nan(self):
main_grads = []
@@ -464,11 +490,12 @@ class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad,
log_num_zeros_in_grad,
params_have_main_grad,
use_contiguous_buffers_in_local_ddp):
super(FP32Optimizer, self).__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp)
self._scale = torch.cuda.FloatTensor([1.0])
@@ -495,6 +522,12 @@ class FP32Optimizer(MegatronOptimizer):
for param in param_group['params']:
param.grad = param.main_grad
# Safe to de-reference model's main_grad after copying.
# (If using contiguous buffers, main_grad's memory should
# persist and therefore should not be deallocated.)
if not self.use_contiguous_buffers_in_local_ddp:
param.main_grad = None
# Clip gradients.
grad_norm = None
if self.clip_grad > 0.0:
...
@@ -22,7 +22,9 @@ from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
tensor_shape,
use_ring_exchange=False,
dtype_=None):
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
@@ -35,9 +37,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
tensor_shape: shape of tensor to receive (this method assumes that all
tensors sent and received in a single function call are
the same shape).
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
API should be used.
dtype_: optional, used when the dtype of the tensor that needs to be
communicated differs from args.params_dtype.
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
@@ -47,28 +53,47 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
# Some legacy inference code doesn't set the tensor shape, so set it here
# to the normal values for gpt/bert. This could be removed if inference
# code is changed to provide tensor_shape.
if tensor_shape is None:
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
override_scatter_gather_tensors_in_pipeline = False
if args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
tensor_chunk_shape = tensor_chunk_shape // \
mpu.get_tensor_model_parallel_world_size()
else:
tensor_chunk_shape = tensor_shape
override_scatter_gather_tensors_in_pipeline = True
else:
tensor_chunk_shape = tensor_shape
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(tensor_chunk_shape,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_chunk_shape,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype)
# Split tensor into smaller chunks if using scatter-gather optimization.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
@@ -112,7 +137,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
@@ -124,8 +150,9 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
return tensor_recv_prev, tensor_recv_next
def recv_forward(tensor_shape=None, dtype_=None, timers=None):
"""Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
@@ -135,13 +162,15 @@ def recv_forward(timers=None):
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=dtype_)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
def recv_backward(tensor_shape=None, timers=None):
"""Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
@@ -152,14 +181,16 @@ def recv_backward(timers=None):
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape)
if timers is not None:
timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
"""Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
@@ -167,12 +198,14 @@ def send_forward(output_tensor, timers=None):
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=dtype_)
if timers is not None:
timers('forward-send').stop()
def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
if timers is not None:
@@ -181,12 +214,13 @@ def send_backward(input_tensor_grad, timers=None):
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape)
if timers is not None:
timers('backward-send').stop()
def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
"""Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
@@ -197,13 +231,14 @@ def send_forward_recv_backward(output_tensor, timers=None):
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape)
if timers is not None:
timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
"""Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage():
input_tensor = None
@@ -214,13 +249,14 @@ def send_backward_recv_forward(input_tensor_grad, timers=None):
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape)
if timers is not None:
timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers('forward-send-forward-recv').start()
@@ -228,13 +264,14 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape)
if timers is not None:
timers('forward-send-forward-recv').stop()
return input_tensor
def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers('backward-send-backward-recv').start()
@@ -242,7 +279,8 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape)
if timers is not None:
timers('backward-send-backward-recv').stop()
return output_tensor_grad
@@ -250,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
recv_next, tensor_shape=None, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
@@ -258,7 +296,8 @@ def send_forward_backward_recv_forward_backward(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape)
if timers is not None:
timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
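A minimal sketch of how a caller might use the extended helpers above, assuming a stage that needs to receive a tensor with a non-default shape and dtype (the snippet is illustrative; the argument names follow the diff):

# Illustrative only: receive an activation with an explicit shape and dtype_.
import torch
from megatron import get_args, p2p_communication

def recv_decoder_activation():
    args = get_args()
    tensor_shape = (args.decoder_seq_length, args.micro_batch_size,
                    args.hidden_size)
    # Passing dtype_ also marks the received tensor as not requiring grad,
    # which is what the inference path wants.
    return p2p_communication.recv_forward(tensor_shape=tensor_shape,
                                          dtype_=torch.float32)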
@@ -25,12 +25,17 @@ from megatron import p2p_communication
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
def get_forward_backward_func():
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
if args.virtual_pipeline_model_parallel_size is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
'number of microbatches is not divisible by pipeline-parallel ' \
'size when using interleaved schedule'
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
@@ -45,11 +50,18 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
passed-in input_tensor is used.
Returns output tensor."""
args = get_args()
timers = get_timers()
timers('forward-compute').start()
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage():
@@ -59,7 +71,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
if mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
return output_tensor
return [output_tensor]
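In short, the list convention introduced here means forward_step now returns either a bare tensor or a list, depending on the stage (a summary sketch, not part of the diff):

# Illustrative summary of forward_step's return value:
#   regular stage, input was a bare tensor:      output_tensor
#   regular stage, input was a list:             [output_tensor]
#   decoder-side stage of encoder_and_decoder:   [output_tensor, encoder_hidden_state]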
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
@@ -70,24 +90,53 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
Returns gradient of loss with respect to input tensor (None if first
stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
args = get_args()
timers = get_timers()
timers('backward-compute').start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor[0])
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = []
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
timers('backward-compute').stop()
@@ -150,6 +199,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
pipeline_parallel_size = mpu.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = mpu.get_pipeline_model_parallel_rank()
args = get_args()
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
num_microbatches = get_num_microbatches() * num_model_chunks
@@ -191,6 +243,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# forward step
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
@@ -202,6 +255,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
return output_tensor
def backward_step_helper(microbatch_id):
@@ -228,7 +286,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(tensor_shape, timers=timers))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
@@ -257,12 +315,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape,
timers=timers)
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev,
tensor_shape=tensor_shape,
timers=timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
@@ -326,7 +387,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
recv_prev=recv_prev, recv_next=recv_next,
tensor_shape=tensor_shape, timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
@@ -340,7 +401,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(tensor_shape, timers=timers))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -352,11 +413,107 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next,
tensor_shape=tensor_shape,
timers=timers))
return losses_reduced
def get_tensor_shapes(rank, model_type):
# Determine right tensor sizes (based on position of rank with respect to split
# rank) and model size.
# Send two tensors if model is T5 and rank is in decoder stage:
# first tensor is decoder (pre-transpose),
# second tensor is encoder (post-transpose).
# If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose).
args = get_args()
tensor_shapes = []
if model_type == ModelType.encoder_and_decoder:
if mpu.is_pipeline_stage_before_split(rank):
# If next rank is after split, then need transpose for encoder_hidden_state.
if mpu.is_pipeline_stage_before_split(rank+1):
tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
else:
tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
else:
tensor_shapes.append((args.decoder_seq_length, args.micro_batch_size, args.hidden_size))
tensor_shapes.append((args.micro_batch_size, args.seq_length, args.hidden_size))
else:
tensor_shapes.append((args.seq_length, args.micro_batch_size, args.hidden_size))
return tensor_shapes
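As a concrete illustration of the shape logic above (hypothetical sizes, not from the commit): with seq_length=512, decoder_seq_length=128, micro_batch_size=4 and hidden_size=1024, get_tensor_shapes would return:

# encoder stage whose next rank is still before the split:
#   [(512, 4, 1024)]                        # pre-transpose
# encoder stage at the split boundary (next rank is a decoder stage):
#   [(4, 512, 1024)]                        # post-transpose encoder_hidden_state
# decoder stage (rank after the split):
#   [(128, 4, 1024), (4, 512, 1024)]        # decoder input + encoder state
# any other model type:
#   [(512, 4, 1024)]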
def recv_forward(tensor_shapes, timers):
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape,
timers=timers))
return input_tensors
def recv_backward(tensor_shapes, timers):
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape,
timers=timers))
return output_tensor_grads
def send_forward(output_tensors, tensor_shapes, timers):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape, timers=timers)
def send_backward(input_tensor_grads, tensor_shapes, timers):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape, timers=timers)
def send_forward_recv_backward(output_tensors, tensor_shapes, timers):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, tensor_shape, timers=timers)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape, timers=timers)
input_tensors.append(input_tensor)
return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
model, optimizer, timers,
forward_only):
@@ -380,17 +537,29 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
model_type = unwrapped_model.model_type
rank = mpu.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(rank-1, model_type)
send_tensor_shapes = get_tensor_shapes(rank, model_type)
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
if not forward_only:
input_tensors = []
output_tensors = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
@@ -398,7 +567,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
@@ -407,22 +576,25 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not last_iteration:
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
else:
output_tensor_grad = \
send_forward_recv_backward(output_tensor,
send_tensor_shapes,
timers=timers)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
@@ -430,11 +602,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
else:
input_tensor = \
send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, timers=timers)
# Run cooldown backward passes.
if not forward_only:
@@ -442,12 +614,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
return losses_reduced
@@ -13,18 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
MAJOR = 1
MINOR = 1.5
# Use the following formatting: (major, minor)
VERSION = (MAJOR, MINOR)
__version__ = '.'.join(map(str, VERSION))
__package_name__ = 'megatron-lm'
__contact_names__ = 'NVIDIA INC'
__url__ = 'https://github.com/NVIDIA/Megatron-LM'
__download_url__ = 'https://github.com/NVIDIA/Megatron-LM/releases'
__description__ = 'Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.'
__license__ = 'See https://github.com/NVIDIA/Megatron-LM/blob/master/LICENSE'
__keywords__ = 'deep learning, Megatron, gpu, NLP, nvidia, pytorch, torch, language'
from .api import (
generate,
generate_and_post_process)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference API."""
import torch
from megatron import mpu
from .communication import broadcast_float_list
from .generation import (
generate_tokens_probs_and_return_on_first_stage,
score_and_return_on_first_stage)
from .tokenization import (
tokenize_prompts,
detokenize_generations)
def generate_and_post_process(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True):
"""Run inference and post-process outputs, i.e., detokenize,
move to cpu and convert to list."""
# Main inference.
tokens, lengths, output_log_probs = generate(
model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=return_output_log_probs,
top_k_sampling=top_k_sampling,
top_p_sampling=top_p_sampling,
temperature=temperature,
add_BOS=add_BOS,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
# Only post-process on first stage.
if mpu.is_pipeline_first_stage():
tokens, prompts_plus_generations, prompts_plus_generations_segments = \
detokenize_generations(tokens, lengths, True)
if return_output_log_probs:
output_log_probs = output_log_probs.cpu().numpy().tolist()
for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
output_log_probs[i] = prob[:len(seg)-1]
return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, tokens
return None
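A minimal usage sketch for the API above (the model handle and prompt are placeholders; results are only returned on the first pipeline stage):

# Illustrative call; `model` is assumed to be an initialized, eval-mode model.
result = generate_and_post_process(
    model,
    prompts=["Megatron-LM is"],
    tokens_to_generate=32,
    return_output_log_probs=True,
    top_p_sampling=0.9)
if result is not None:  # non-None only on the first pipeline stage
    prompts_plus_generations, segments, output_log_probs, tokens = result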
def generate(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
top_k_sampling=0,
top_p_sampling=0.0,
temperature=1.0,
add_BOS=False,
use_eod_token_for_early_termination=True):
"""Given prompts and input parameters, run inference and return:
tokens: prompts plus the generated tokens.
lengths: length of the prompt + generations. Note that we can
discard tokens in the tokens tensor that are after the
corresponding length.
output_log_probs: log probs of the tokens.
"""
# Make sure input params are available to all ranks.
values = [tokens_to_generate,
return_output_log_probs,
top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination]
values_float_tensor = broadcast_float_list(7, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
top_k_sampling = int(values_float_tensor[2].item())
top_p_sampling = values_float_tensor[3].item()
temperature = values_float_tensor[4].item()
add_BOS = bool(values_float_tensor[5].item())
use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
# Tokenize prompts and get the batch.
# Note that these tensors are broadcast to all ranks.
if torch.distributed.get_rank() == 0:
assert prompts is not None
context_tokens_tensor, context_length_tensor = tokenize_prompts(
prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
if tokens_to_generate == 0:
return score_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor)
# Main inference function.
# Note that the outputs are available on the first stage.
return generate_tokens_probs_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
top_k=top_k_sampling,
top_p=top_p_sampling,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Communications utilities."""
import torch
from megatron import mpu
# TODO: use functions from megatron/p2p
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
input buffer in place."""
if not mpu.is_pipeline_first_stage():
assert recv_buffer is not None
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_buffer,
mpu.get_pipeline_model_parallel_prev_rank())
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# TODO: use functions from megatron/p2p
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
if not mpu.is_pipeline_last_stage():
assert tensor is not None
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor,
mpu.get_pipeline_model_parallel_next_rank())
reqs = torch.distributed.batch_isend_irecv([send_next_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
def _is_cuda(tensor):
"""Check if a tensor is not none and is cuda."""
assert tensor is not None
assert tensor.is_cuda
def _is_cuda_contiguous(tensor):
"""Check if a tensor is not none, is cuda, and is contiguous."""
_is_cuda(tensor)
assert tensor.is_contiguous()
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
is_last_stage = mpu.is_pipeline_last_stage()
# If first stage and last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if mpu.is_pipeline_first_stage() and is_last_stage:
return tensor
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Get the group and corresponding source rank.
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_pipeline_model_parallel_group()
torch.distributed.broadcast(tensor, src, group)
return tensor
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return tensor
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor, src, group)
else:
tensor = None
return tensor
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Copy tensor values from last stage into the first stage.
Note that the input tensor is updated in place."""
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last stage are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
_is_cuda(tensor)
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
if is_contiguous:
tensor_ = tensor
else:
if is_last_stage:
tensor_ = tensor.contiguous()
else:
tensor_ = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
# Broadcast from last stage into the first stage.
torch.distributed.broadcast(tensor_, src, group)
# Update the first stage tensor
if is_first_stage and not is_contiguous:
tensor[...] = tensor_
def broadcast_tensor(size, dtype, tensor=None, rank=0):
""" Given size and type of a tensor on all ranks and the tensor value
only on a specific rank, broadcast from that rank to all other ranks.
"""
if torch.distributed.get_rank() == rank:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
device=torch.cuda.current_device())
torch.distributed.broadcast(tensor, rank)
return tensor
def broadcast_list(size, dtype, list_values=None, rank=0):
"""Broadcast a list of values with a given type."""
tensor = None
if torch.distributed.get_rank() == rank:
tensor = torch.tensor(list_values, dtype=dtype,
device=torch.cuda.current_device())
return broadcast_tensor(size, dtype, tensor=tensor, rank=rank)
def broadcast_int_list(size, int_list=None, rank=0):
"""Broadcast a list of interger values."""
return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
def broadcast_float_list(size, float_list=None, rank=0):
"""Broadcast a list of float values."""
return broadcast_list(size, torch.float32, list_values=float_list,
rank=rank)
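A small sketch of how these broadcast helpers are typically used (it mirrors the unpacking done in generate(); the values are illustrative):

# All ranks make the collective call; only the source rank's values are used.
values = None
if torch.distributed.get_rank() == 0:
    values = [64, 1, 0, 0.9, 1.0]   # mixed ints/bools/floats packed as floats
values_float_tensor = broadcast_float_list(5, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())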
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Forward step utilities."""
from collections.abc import Iterable
import torch
from megatron import (
get_args,
mpu)
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference."""
def __init__(self, max_batch_size, max_sequence_len):
"""Note that offsets are set to zero and we always set the
flag to allocate memory. After the first call, make sure to
set this flag to False."""
self.max_sequence_len = max_sequence_len
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
self.batch_size_offset = 0
self.key_value_memory_dict = {}
class ForwardStep:
"""Forward step function with all the communications.
We use a class here to hide the inference parameters
from the outside caller."""
def __init__(self, model, max_batch_size, max_sequence_len):
"""Set values so we don't need to do it multiple times."""
# Make sure model is in eval mode.
assert not isinstance(model, Iterable), \
'interleaving schedule is not supported for inference'
model.eval()
self.model = model
# Initialize inference parameters.
self.inference_params = InferenceParams(max_batch_size,
max_sequence_len)
# Pipelining arguments.
args = get_args()
self.pipeline_size_larger_than_one = (
args.pipeline_model_parallel_size > 1)
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
def __call__(self, tokens, position_ids, attention_mask):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return _with_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params,
micro_batch_size)
return _no_pipelining_forward_step(self.model,
tokens,
position_ids,
attention_mask,
self.inference_params)
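A short sketch of how ForwardStep is meant to be driven (model construction and batch preparation are elided; the sizes are assumptions):

# Illustrative only.
forward = ForwardStep(model, max_batch_size=8, max_sequence_len=2048)
logits = forward(tokens, position_ids, attention_mask)
# forward.inference_params.sequence_len_offset has now advanced, so a
# generation loop would feed only the newly generated tokens on later calls.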
def _get_recv_buffer_dtype(args):
"""Receive happens between the layers."""
if args.fp32_residual_connection:
return torch.float
return args.params_dtype
def _allocate_recv_buffer(batch_size, sequence_length):
"""Receive happens between the layers with size [s, b, h]."""
if mpu.is_pipeline_first_stage():
return None
args = get_args()
recv_size = (sequence_length, batch_size, args.hidden_size)
return torch.empty(recv_size,
dtype=_get_recv_buffer_dtype(args),
device=torch.cuda.current_device())
def _forward_step_helper(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""Single forward step. Update the allocate memory flag so
only the first time the memory is allocated."""
batch_size = tokens.size(0)
sequence_length = tokens.size(1)
if recv_buffer is None:
recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)
# Receive from previous stage.
recv_from_prev_pipeline_rank_(recv_buffer)
# Forward pass through the model.
model.set_input_tensor(recv_buffer)
output_tensor = model(tokens, position_ids, attention_mask,
inference_params=inference_params)
# Send output to the next stage.
send_to_next_pipeline_rank(output_tensor)
return output_tensor
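# Note on stage behavior (added comment, not in the original file): on the
# first pipeline stage the receive buffer above is None, so
# set_input_tensor(None) is expected to make the model consume the token
# embeddings directly; on later stages the model consumes the hidden states
# received from the previous rank instead.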
def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, recv_buffer=None):
"""If recv_buffer is none, we will allocate one on the fly."""
# Run a simple forward pass.
output_tensor = _forward_step_helper(model, tokens, position_ids,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Update the sequence length offset.
inference_params.sequence_len_offset += tokens.size(1)
logits = None
if mpu.is_pipeline_last_stage():
logits = output_tensor
return logits
def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
inference_params, micro_batch_size):
"""No interleaving is supported."""
sequence_length = tokens.size(1)
batch_size = tokens.size(0)
# Divide the batch dimension into micro batches.
num_micro_batches, last_chunk = divmod(batch_size,
micro_batch_size)
if last_chunk > 0:
num_micro_batches += 1
# Preallocate memory for output logits.
logits = None
if mpu.is_pipeline_last_stage():
args = get_args()
logits = torch.empty(
(batch_size, sequence_length, args.padded_vocab_size),
dtype=torch.float32, device=torch.cuda.current_device())
# Preallocate recv buffer.
recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)
for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
# Run a simple forward pass.
if this_micro_batch_size != micro_batch_size:
recv_buffer = None
output = _forward_step_helper(model, tokens2use, position_ids2use,
attention_mask, inference_params,
recv_buffer=recv_buffer)
# Adjust the batch size offset to account for the micro-batch.
inference_params.batch_size_offset += this_micro_batch_size
# Copy logits.
if mpu.is_pipeline_last_stage():
logits[start:end, ...] = output
# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
inference_params.sequence_len_offset += sequence_length
# and reset the batch size offset
inference_params.batch_size_offset = 0
return logits
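# A worked example of the micro-batch split above, added for illustration
# only (the sizes are hypothetical). With batch_size = 10 and
# micro_batch_size = 4, divmod(10, 4) = (2, 2), so the last chunk is
# non-empty and we run 3 micro batches of sizes 4, 4 and 2; the preallocated
# recv buffer is sized for 4 and dropped (recv_buffer = None) for the final,
# smaller micro batch.
def _micro_batch_split_sketch(batch_size=10, micro_batch_size=4):
    num_micro_batches, last_chunk = divmod(batch_size, micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
    sizes = [min(micro_batch_size, batch_size - i * micro_batch_size)
             for i in range(num_micro_batches)]
    return num_micro_batches, sizes  # -> (3, [4, 4, 2])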
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generation utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args, get_tokenizer, mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
copy_from_last_to_first_pipeline_stage,
broadcast_from_last_pipeline_stage,
broadcast_from_last_to_first_pipeline_stage)
from .forward_step import ForwardStep
from .sampling import sample
def score_and_return_on_first_stage(model, tokens, lengths):
"""Function for just scoring.
Arguments:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max_prompt_length]
lengths: original prompt length, size: [b]
    Note: Apart from the model, the other arguments only need to be
        available on rank 0.
Outputs:
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
batch_size = tokens.size(0)
max_prompt_length = lengths.max().item()
assert max_prompt_length == tokens.size(1)
max_sequence_length = min(max_prompt_length, args.max_position_embeddings)
# forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length)
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
if mpu.is_pipeline_last_stage():
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
# =============
    # Run inference
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
        # logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens, position_ids, attention_mask)
if mpu.is_pipeline_last_stage():
            # The last stage should always have an output.
assert logits is not None
log_probs = F.log_softmax(logits, dim=2)
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(tokens[:, 1:], 2)
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, lengths, output_log_probs
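# A minimal sketch of the shift-by-one gather used above, added for
# illustration only (the toy sizes are hypothetical). log_probs has shape
# [b, s, v] and tokens has shape [b, s]; the value stored at position i is
# the log probability of tokens[:, i + 1] under the distribution produced at
# position i, hence the tokens[:, 1:] indexing.
def _logprob_gather_sketch():
    batch, seq, vocab = 2, 5, 11
    logits = torch.randn(batch, seq, vocab)
    tokens = torch.randint(vocab, (batch, seq))
    log_probs = F.log_softmax(logits, dim=2)
    indices = torch.unsqueeze(tokens[:, 1:], 2)           # [b, s-1, 1]
    out = torch.gather(log_probs, 2, indices).squeeze(2)  # [b, s-1]
    return out.shape  # torch.Size([2, 4])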
def generate_tokens_probs_and_return_on_first_stage(
model, tokens, lengths,
return_output_log_probs=False,
top_k=0, top_p=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True):
"""Main token generation function.
Arguments:
model: no interleaving is supported.
tokens: prompt tokens extended to be of size [b, max-sequence-length]
lengths: original prompt length, size: [b]
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
from the original logit.
top_k, top_p: top-k and top-p sampling parameters.
            Note that top-k = 1 is greedy. Also, these parameters are
            mutually exclusive, meaning that:
if top-k > 0 then we expect top-p=0.
if top-p > 0 then we check for top-k=0.
temperature: sampling temperature.
use_eod_token_for_early_termination: if True, do early termination if
all the sequences have reached this token.
    Note: Apart from the model, the other arguments only need to be
        available on rank 0.
    Outputs: Note that the size is adjusted to a lower value than
        max-sequence-length if generation is terminated early.
tokens: prompt and generated tokens. size: [b, :]
generated_sequence_lengths: total length (including prompt) of
the generated sequence. size: [b]
output_log_probs: log probability of the selected tokens. size: [b, s]
"""
args = get_args()
tokenizer = get_tokenizer()
batch_size = tokens.size(0)
min_prompt_length = lengths.min().item()
max_sequence_length = tokens.size(1)
max_sequence_length = min(max_sequence_length, args.max_position_embeddings)
# forward step.
forward_step = ForwardStep(model, batch_size, max_sequence_length)
    # termination_id lets us stop generation once that id is produced;
    # by default it is the tokenizer's end-of-document token.
if hasattr(args, 'eos_id'):
termination_id = args.eos_id
else:
termination_id = tokenizer.eod
# ===================
# Pre-allocate memory
# ===================
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
    # Lengths of the generated sequences, including the prompts.
generated_sequence_lengths = None
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
# Whether we have reached a termination id.
is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
device=torch.cuda.current_device())
# =============
    # Run inference
# =============
with torch.no_grad():
attention_mask, position_ids = _build_attention_mask_and_position_ids(
tokens)
prev_context_length = 0
for context_length in range(min_prompt_length, max_sequence_length):
# Pick the slice that we need to pass through the network.
tokens2use = tokens[:, prev_context_length:context_length]
positions2use = position_ids[:, prev_context_length:context_length]
attention_mask2use = attention_mask[
..., prev_context_length:context_length, :context_length]
            # logits will be meaningful only in the last pipeline stage.
logits = forward_step(tokens2use, positions2use, attention_mask2use)
if mpu.is_pipeline_last_stage():
                # The last stage should always have an output.
assert logits is not None
# Sample.
last_token_logits = logits[:, -1, :]
new_sample = sample(last_token_logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
vocab_size=tokenizer.vocab_size)
                # If a prompt length is smaller than or equal to the current
                # context length, it means we have started generating tokens
                # for that sequence.
started = lengths <= context_length
# Update the tokens.
tokens[started, context_length] = new_sample[started]
# Calculate the log probabilities.
if return_output_log_probs:
log_probs = F.log_softmax(logits, dim=2)
if return_output_log_probs:
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
# the token which we selected in the current logits,
# so shift by 1.
indices = torch.unsqueeze(
tokens[
:,
(prev_context_length + 1):(context_length + 1)],
2)
output_log_probs[:,
prev_context_length:context_length] = \
torch.gather(log_probs, 2, indices).squeeze(2)
# Update the tokens on the first stage so the next input to
# the network is correct.
copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
tokens[:, context_length])
# Update the context length for the next token generation.
prev_context_length = context_length
# Check if all the sequences have hit the termination_id.
done = None
if mpu.is_pipeline_last_stage():
done_token = (new_sample == termination_id).byte() & \
started.byte()
just_finished = (done_token & ~is_generation_done).bool()
generated_sequence_lengths[just_finished.view(-1)] = \
context_length + 1
is_generation_done = is_generation_done | done_token
done = torch.all(is_generation_done)
done = broadcast_from_last_pipeline_stage(1, torch.uint8,
tensor=done)
if use_eod_token_for_early_termination and done:
break
    # ==================================================
    # Update the lengths based on max generated length.
    # ==================================================
tokens = tokens[:, :(context_length + 1)]
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = output_log_probs[:, :context_length]
# ======================================
# Broadcast to the first pipeline stage.
# ======================================
generated_sequence_lengths = broadcast_from_last_to_first_pipeline_stage(
batch_size, torch.int64, generated_sequence_lengths)
if return_output_log_probs:
output_log_probs_size = (batch_size, context_length)
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
return tokens, generated_sequence_lengths, output_log_probs
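# A toy illustration of the per-sequence bookkeeping above, added for
# illustration only (prompt lengths and token ids are hypothetical):
# sequences whose prompt extends past the current context keep their prompt
# token, while "started" sequences take the newly sampled token; the loop
# above additionally stops early once every row has emitted termination_id.
def _generation_bookkeeping_sketch():
    lengths = torch.tensor([2, 4])           # prompt lengths
    context_length = 3
    started = lengths <= context_length      # tensor([True, False])
    new_sample = torch.tensor([7, 9])
    tokens = torch.zeros(2, 6, dtype=torch.long)
    tokens[started, context_length] = new_sample[started]
    return started, tokens[:, context_length]  # row 0 updated, row 1 kept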
def _build_attention_mask_and_position_ids(tokens):
"""Build the attention mask and postition ids for the input tokens."""
# Since we are not interested in loss-mask and reset attention/position
# is also False, eod_token is not used so it is safe to set it to None.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
eod_token=None,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False)
return attention_mask, position_ids
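# A simplified, self-contained sketch of the left-to-right mask and position
# ids built above, added for illustration only; the exact masking convention
# is an assumption here, and the real helper lives in megatron.utils.
def _ltor_mask_and_positions_sketch(batch_size=2, seq_length=4):
    # Lower-triangular causal structure: token i may attend to tokens <= i.
    causal = torch.tril(torch.ones(seq_length, seq_length))
    attention_mask = (causal < 0.5).unsqueeze(0).unsqueeze(0)  # [1, 1, s, s]
    position_ids = torch.arange(seq_length).unsqueeze(0).expand(
        batch_size, -1)                                        # [b, s]
    return attention_mask, position_ids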
...@@ -266,7 +266,7 @@ class _GPT2BPETokenizer(AbstractTokenizer):
        super().__init__(name)
        if special_tokens is not None:
-            # for controllable dialog, special_tokens: "[SEP],[CTRL],[PAD]"
+            # special_tokens: "[SEP],[PAD]"
            special_tokens = special_tokens.split(",")
        else:
            special_tokens = []
...@@ -277,8 +277,6 @@ class _GPT2BPETokenizer(AbstractTokenizer):
        if special_tokens is not None and len(special_tokens) > 0:
            if "[SEP]" in special_tokens:
                self.sep_id = self.tokenizer.special_tokens['[SEP]']
-            if "[CTRL]" in special_tokens:
-                self.ctrl_id = self.tokenizer.special_tokens['[CTRL]']
            if "[PAD]" in special_tokens:
                self.pad_id = self.tokenizer.special_tokens['[PAD]']
......