Commit 523ec9cc authored by wangsen
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import operator
from functools import reduce
from typing import Callable, List, Optional, Tuple, Union
import torch
from megatron import core
from megatron.core import ModelParallelConfig
from megatron.core.parallel_state import (
get_pipeline_model_parallel_group,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_pipeline_model_parallel_rank,
get_pipeline_model_parallel_world_size,
)
# Types
Shape = Union[List[int], torch.Size]
def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next, config):
"""Communicate tensor shapes between stages. Used to communicate
tensor shapes before the actual tensor communication happens.
This is required when the sequence lengths across micro batches
are not uniform.
Args:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
Returns:
(recv_prev_shape, recv_next_shape)
"""
recv_prev_shape_tensor = None
recv_next_shape_tensor = None
send_prev_shape_tensor = None
send_next_shape_tensor = None
if recv_prev:
recv_prev_shape_tensor = torch.empty(
(3), device=torch.cuda.current_device(), dtype=torch.int64
)
if recv_next:
recv_next_shape_tensor = torch.empty(
(3), device=torch.cuda.current_device(), dtype=torch.int64
)
if tensor_send_prev is not None:
send_prev_shape_tensor = torch.tensor(
tensor_send_prev.size(), device=torch.cuda.current_device(), dtype=torch.int64
)
if tensor_send_next is not None:
send_next_shape_tensor = torch.tensor(
tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64
)
if config.use_ring_exchange_p2p:
torch.distributed.ring_exchange(
tensor_send_prev=send_prev_shape_tensor,
tensor_recv_prev=recv_prev_shape_tensor,
tensor_send_next=send_next_shape_tensor,
tensor_recv_next=recv_next_shape_tensor,
group=get_pipeline_model_parallel_group(),
)
else:
ops = []
if send_prev_shape_tensor is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
send_prev_shape_tensor,
get_pipeline_model_parallel_prev_rank(),
)
ops.append(send_prev_op)
if recv_prev_shape_tensor is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
recv_prev_shape_tensor,
get_pipeline_model_parallel_prev_rank(),
)
ops.append(recv_prev_op)
if send_next_shape_tensor is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
send_next_shape_tensor,
get_pipeline_model_parallel_next_rank(),
)
ops.append(send_next_op)
if recv_next_shape_tensor is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
recv_next_shape_tensor,
get_pipeline_model_parallel_next_rank(),
)
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against a race condition when using batch_isend_irecv().
# Should be removed once the bug with batch_isend_irecv is resolved.
torch.cuda.synchronize()
recv_prev_shape = [0, 0, 0]
if recv_prev_shape_tensor is not None:
recv_prev_shape = recv_prev_shape_tensor.tolist()
recv_next_shape = [0, 0, 0]
if recv_next_shape_tensor is not None:
recv_next_shape = recv_next_shape_tensor.tolist()
return recv_prev_shape, recv_next_shape
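# Illustrative sketch, not part of the original module: with
# config.variable_seq_lengths enabled, _communicate() below first exchanges
# shapes so that each rank can size its receive buffers. For example, on an
# intermediate stage (output_tensor and config are assumed to exist):
#
#     recv_prev_shape, _ = _communicate_shapes(
#         tensor_send_next=output_tensor,
#         tensor_send_prev=None,
#         recv_prev=True,
#         recv_next=False,
#         config=config,
#     )
#     # recv_prev_shape is a plain Python list, e.g. [seq_len, micro_batch, hidden].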
def _batched_p2p_ops(
*,
tensor_send_prev: Optional[torch.Tensor],
tensor_recv_prev: Optional[torch.Tensor],
tensor_send_next: Optional[torch.Tensor],
tensor_recv_next: Optional[torch.Tensor],
group: torch.distributed.ProcessGroup
):
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_prev,
get_pipeline_model_parallel_prev_rank(),
group,
)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_prev,
get_pipeline_model_parallel_prev_rank(),
group,
)
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor_send_next,
get_pipeline_model_parallel_next_rank(),
group,
)
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv_next,
get_pipeline_model_parallel_next_rank(),
group,
)
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
else:
reqs = []
return reqs
def _p2p_ops(
*,
tensor_send_prev: Optional[torch.Tensor],
tensor_recv_prev: Optional[torch.Tensor],
tensor_send_next: Optional[torch.Tensor],
tensor_recv_next: Optional[torch.Tensor],
group: torch.distributed.ProcessGroup
):
reqs = []
rank = get_pipeline_model_parallel_rank()
even_send_odd_recv_group = group
if get_pipeline_model_parallel_world_size() == 2:
# Use the global process group for one of the two p2p communications
# to allow the overlap of the independent communications.
# Using the global process group is compatible because the pipeline-parallel
# communications set the source and destination by global rank.
even_recv_odd_send_group = torch.distributed.group.WORLD
else:
even_recv_odd_send_group = group
if get_pipeline_model_parallel_rank() % 2 == 0:
if tensor_send_next is not None:
send_next_req = torch.distributed.isend(
tensor=tensor_send_next,
dst=get_pipeline_model_parallel_next_rank(),
group=even_send_odd_recv_group,
)
reqs.append(send_next_req)
if tensor_recv_prev is not None:
recv_prev_req = torch.distributed.irecv(
tensor=tensor_recv_prev,
src=get_pipeline_model_parallel_prev_rank(),
group=even_recv_odd_send_group,
)
reqs.append(recv_prev_req)
if tensor_send_prev is not None:
send_prev_req = torch.distributed.isend(
tensor=tensor_send_prev,
dst=get_pipeline_model_parallel_prev_rank(),
group=even_send_odd_recv_group,
)
reqs.append(send_prev_req)
if tensor_recv_next is not None:
recv_next_req = torch.distributed.irecv(
tensor=tensor_recv_next,
src=get_pipeline_model_parallel_next_rank(),
group=even_recv_odd_send_group,
)
reqs.append(recv_next_req)
else:
if tensor_recv_prev is not None:
recv_prev_req = torch.distributed.irecv(
tensor=tensor_recv_prev,
src=get_pipeline_model_parallel_prev_rank(),
group=even_send_odd_recv_group,
)
reqs.append(recv_prev_req)
if tensor_send_next is not None:
send_next_req = torch.distributed.isend(
tensor=tensor_send_next,
dst=get_pipeline_model_parallel_next_rank(),
group=even_recv_odd_send_group,
)
reqs.append(send_next_req)
if tensor_recv_next is not None:
recv_next_req = torch.distributed.irecv(
tensor=tensor_recv_next,
src=get_pipeline_model_parallel_next_rank(),
group=even_send_odd_recv_group,
)
reqs.append(recv_next_req)
if tensor_send_prev is not None:
send_prev_req = torch.distributed.isend(
tensor=tensor_send_prev,
dst=get_pipeline_model_parallel_prev_rank(),
group=even_recv_odd_send_group,
)
reqs.append(send_prev_req)
return reqs
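# Illustrative note, not part of the original module: the even/odd branching in
# _p2p_ops fixes the order in which operations are issued on each rank so that
# neighbouring ranks of opposite parity always pair a send with a matching recv
# and cannot deadlock:
#
#     even pipeline rank: isend(next), irecv(prev), isend(prev), irecv(next)
#     odd pipeline rank:  irecv(prev), isend(next), irecv(next), isend(prev)
#
# e.g. an even rank's leading isend to its next (odd) rank matches that rank's
# leading irecv from its previous rank, and so on down both lists.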
def _communicate(
*,
tensor_send_next: Optional[torch.Tensor],
tensor_send_prev: Optional[torch.Tensor],
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
wait_on_reqs: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Optional[List]]:
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
Args:
tensor_send_next (torch.Tensor, optional):
Tensor to send to next rank (no tensor sent if None)
tensor_send_prev (torch.Tensor, optional):
Tensor to send to prev rank (no tensor sent if None)
recv_prev (boolean, required):
whether tensor should be received from previous rank.
recv_next (boolean, required):
whether tensor should be received from next rank.
tensor_shape (List[int] or torch.Size, required):
shape of tensor to receive (this method assumes that all
tensors sent and received in a single function call are
the same shape).
wait_on_reqs (boolean, optional, default=True):
For non-batched p2p communication, wait on each request
before returning.
Returns:
tuple containing
- tensor_recv_prev: torch.Tensor if recv_prev is True, None otherwise.
- tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
- reqs: the outstanding communication requests when wait_on_reqs is False, None otherwise.
"""
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
if not config.variable_seq_lengths:
recv_prev_shape = tensor_shape
recv_next_shape = tensor_shape
else:
recv_prev_shape, recv_next_shape = _communicate_shapes(
tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
)
if recv_prev:
if config.pipeline_dtype is None:
raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_prev is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_prev = torch.empty(
recv_prev_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=config.pipeline_dtype,
)
if recv_next:
if config.pipeline_dtype is None:
raise RuntimeError("dtype must be provided if recv_next is True")
if tensor_shape is None:
raise RuntimeError(
"tensor_shape must be specified if recv_next is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
)
tensor_recv_next = torch.empty(
recv_next_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=config.pipeline_dtype,
)
# Send tensors in both the forward and backward directions as appropriate.
if config.use_ring_exchange_p2p:
def _ring_exchange_wrapper(**kwargs):
torch.distributed.ring_exchange(**kwargs)
return []
p2p_func = _ring_exchange_wrapper
elif config.batch_p2p_comm:
assert wait_on_reqs
p2p_func = _batched_p2p_ops
else:
p2p_func = _p2p_ops
reqs = p2p_func(
tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=get_pipeline_model_parallel_group(),
)
if wait_on_reqs and len(reqs) > 0:
for req in reqs:
req.wait()
reqs = None
if config.batch_p2p_comm and config.batch_p2p_sync:
# To protect against race condition when using batch_isend_irecv().
# User should assert that we have a modern enough PyTorch to not need this
torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next, reqs
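# Illustrative sketch, not part of the original module: the wrappers below map
# the common cases onto _communicate. A forward receive on a non-first stage,
# for instance, reduces to the following (the shape names are placeholders and
# config is assumed to carry pipeline_dtype):
#
#     input_tensor, _, _ = _communicate(
#         tensor_send_next=None,
#         tensor_send_prev=None,
#         recv_prev=True,
#         recv_next=False,
#         tensor_shape=(seq_length, micro_batch_size, hidden_size),
#         config=config,
#     )
#
# With wait_on_reqs=False (only allowed for non-batched p2p), the third return
# value holds the outstanding requests so that the caller can overlap the
# communication with compute and wait on the handles later.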
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
""" Receive tensor from previous rank in pipeline (forward receive).
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None
else:
if config.timers is not None:
config.timers('forward-recv', log_level=2).start()
input_tensor, _, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('forward-recv').stop()
return input_tensor
def recv_backward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
"""Receive tensor from next rank in pipeline (backward receive).
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None
else:
if config.timers is not None:
config.timers('backward-recv', log_level=2).start()
_, output_tensor_grad, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('backward-recv').stop()
return output_tensor_grad
def send_forward(output_tensor: torch.Tensor, config: ModelParallelConfig) -> None:
"""Send tensor to next rank in pipeline (forward send).
See _communicate for argument details.
"""
if not core.parallel_state.is_pipeline_last_stage():
if config.timers is not None:
config.timers('forward-send', log_level=2).start()
_communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
tensor_shape=None,
config=config,
)
if config.timers is not None:
config.timers('forward-send').stop()
def send_backward(input_tensor_grad: torch.Tensor, config: ModelParallelConfig) -> None:
"""Send tensor to previous rank in pipeline (backward send).
See _communicate for argument details.
"""
if not core.parallel_state.is_pipeline_first_stage():
if config.timers is not None:
config.timers('backward-send', log_level=2).start()
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
tensor_shape=None,
config=config,
)
if config.timers is not None:
config.timers('backward-send').stop()
def send_forward_recv_backward(
output_tensor: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
"""Batched send and recv with next rank in pipeline.
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_last_stage():
output_tensor_grad = None
else:
if config.timers is not None:
config.timers('forward-send-backward-recv', log_level=2).start()
_, output_tensor_grad, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('forward-send-backward-recv').stop()
return output_tensor_grad
def send_backward_recv_forward(
input_tensor_grad: torch.Tensor, tensor_shape: Shape, config: ModelParallelConfig
) -> torch.Tensor:
"""Batched send and recv with previous rank in pipeline.
See _communicate for argument details.
"""
if core.parallel_state.is_pipeline_first_stage():
input_tensor = None
else:
if config.timers is not None:
config.timers('backward-send-forward-recv', log_level=2).start()
input_tensor, _, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('backward-send-forward-recv').stop()
return input_tensor
def send_forward_recv_forward(
output_tensor: torch.Tensor,
recv_prev: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
overlap_p2p_comm: bool = False,
) -> torch.Tensor:
"""Batched recv from previous rank and send to next rank in pipeline.
See _communicate for argument details.
"""
if config.timers is not None:
config.timers('forward-send-forward-recv', log_level=2).start()
input_tensor, _, wait_handles = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape,
wait_on_reqs=(not overlap_p2p_comm),
config=config,
)
if config.timers is not None:
config.timers('forward-send-forward-recv').stop()
if overlap_p2p_comm:
return input_tensor, wait_handles
return input_tensor
def send_backward_recv_backward(
input_tensor_grad: torch.Tensor,
recv_next: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
overlap_p2p_comm: bool = False,
) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline.
See _communicate for argument details.
"""
if config.timers is not None:
config.timers('backward-send-backward-recv', log_level=2).start()
_, output_tensor_grad, wait_handles = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape,
wait_on_reqs=(not overlap_p2p_comm),
config=config,
)
if config.timers is not None:
config.timers('backward-send-backward-recv').stop()
if overlap_p2p_comm:
return output_tensor_grad, wait_handles
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor: torch.Tensor,
input_tensor_grad: torch.Tensor,
recv_prev: bool,
recv_next: bool,
tensor_shape: Shape,
config: ModelParallelConfig,
) -> torch.Tensor:
"""Batched send and recv with previous and next ranks in pipeline.
See _communicate for argument details.
"""
if config.timers is not None:
config.timers('forward-backward-send-forward-backward-recv', log_level=2).start()
input_tensor, output_tensor_grad, _ = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
)
if config.timers is not None:
config.timers('forward-backward-send-forward-backward-recv').stop()
return input_tensor, output_tensor_grad
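# Illustrative note, not part of the original module: in the non-interleaved
# 1F1B steady state (see forward_backward_pipelining_without_interleaving in
# the schedule code below), an intermediate stage alternates the two fused
# calls, so activations flow downstream while gradients flow upstream in the
# same communication step:
#
#     output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shape, config)
#     ...  # backward step for the oldest in-flight microbatch
#     input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shape, config)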
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import contextlib
from typing import Callable, Iterator, List, Optional, Union
import torch
from torch.autograd.variable import Variable
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import get_attr_wrapped_model, get_model_config, get_model_type
# Types
Shape = Union[List[int], torch.Size]
def get_forward_backward_func():
"""Retrieves the appropriate forward_backward function given the
configuration of parallel_state.
Returns a function that will perform all of the forward and
backward passes of the model given the pipeline model parallel
world size and virtual pipeline model parallel world size in the
global parallel_state.
Note that if using sequence parallelism, the sequence length component of
the tensor shape is updated to original_sequence_length /
tensor_model_parallel_world_size.
The function returned takes the following arguments:
forward_step_func (required): A function that takes a data
iterator and a model as its arguments and returns the model's
forward output and the loss function. The loss function should
take one torch.Tensor and return a torch.Tensor of loss and a
dictionary of string -> torch.Tensor.
A third argument, checkpoint_activations_microbatch, indicates
whether the activations for this microbatch should be
checkpointed. A None value for this argument indicates that
the default from the configuration should be used. This argument
is only passed when num_microbatches_with_partial_activation_checkpoints
is set in the config.
For example:
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
data, loss_mask = next(data_iterator)
output = model(data)
return output, partial(loss_func, loss_mask)
forward_backward_func(forward_step_func=forward_step, ...)
data_iterator (required): an iterator over the data, will be
passed as is to forward_step_func. Expected to be a list of
iterators in the case of interleaved pipeline parallelism.
model (required): the actual model. Expected to be a list of modules in the case of interleaved
pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.
num_microbatches (int, required):
The number of microbatches to go through
seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
in the config is True. Otherwise, each microbatch in the current global batch must use
this sequence length.
micro_batch_size (int, required): The number of sequences in a microbatch.
decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
transformer. This is ignored for a single-stack transformer.
forward_only (optional, default = False): Perform only the forward step
collect_non_loss_data (optional, bool, default=False): If True, the last pipeline stage
stores the data returned by loss_func(output_tensor, non_loss_data=True) instead of
reduced losses (e.g. for validation or inference-time metrics).
first_val_step (bool, optional): Whether this is the first step of the validation phase.
Used by Transformer Engine modules to update their fp8 weights only on the first
validation step.
"""
pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
if pipeline_model_parallel_size > 1:
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
forward_backward_func = forward_backward_pipelining_with_interleaving
else:
forward_backward_func = forward_backward_pipelining_without_interleaving
else:
forward_backward_func = forward_backward_no_pipelining
return forward_backward_func
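# Illustrative sketch, not part of the original module: a typical training step
# retrieves the schedule once and calls it with the per-iteration arguments.
# forward_step, train_data_iterator, model and get_num_microbatches() are
# assumed to exist, as in the docstring example above:
#
#     forward_backward_func = get_forward_backward_func()
#     losses_reduced = forward_backward_func(
#         forward_step_func=forward_step,
#         data_iterator=train_data_iterator,
#         model=model,
#         num_microbatches=get_num_microbatches(),
#         seq_length=seq_length,
#         micro_batch_size=micro_batch_size,
#         decoder_seq_length=None,
#         forward_only=False,
#     )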
def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
'''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
'''
if (out is None) or (not deallocate_pipeline_outputs):
return
assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
assert out._base is None, "counter-productive to free a view of another tensor."
out.data = torch.empty((1,), device=out.device, dtype=out.dtype,)
def custom_backward(output, grad_output):
'''Directly call C++ autograd engine.
To make the 'deallocate_output_tensor' (above) optimization work, the C++
autograd engine must be called directly, bypassing Pytorch's
torch.autograd.backward. Pytorch's 'backward' checks that the output and
grad have the same shape, while C++'s 'backward' does not.
'''
assert output.numel() == 1, "output should be pseudo-'freed' in schedule, to optimize memory"
assert isinstance(output, torch.Tensor), "output == '%s'." % type(output).__name__
assert isinstance(grad_output, (torch.Tensor, type(None))), (
"grad_output == '%s'." % type(grad_output).__name__
)
# Handle scalar output
if grad_output is None:
assert output.numel() == 1, "implicit grad requires scalar output."
grad_output = torch.ones_like(output, memory_format=torch.preserve_format,)
# Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
Variable._execution_engine.run_backward(
tensors=(output,),
grad_tensors=(grad_output,),
keep_graph=False,
create_graph=False,
inputs=tuple(),
allow_unreachable=True,
accumulate_grad=True,
)
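# Illustrative note, not part of the original module: the two helpers above are
# used as a pair. Once an activation has been sent downstream, its storage is
# pseudo-freed, and the eventual backward pass must therefore go through
# custom_backward instead of torch.autograd.backward (whose shape check would
# fail on the pseudo-freed output):
#
#     p2p_communication.send_forward(output_tensor, config)
#     deallocate_output_tensor(output_tensor, deallocate_pipeline_outputs=True)
#     ...
#     custom_backward(output_tensor, output_tensor_grad)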
def set_current_microbatch(model, microbatch_id):
decoder_exists = True
decoder = None
try:
decoder = get_attr_wrapped_model(model, "decoder")
except RuntimeError:
decoder_exists = False
if decoder_exists and decoder is not None:
decoder.current_microbatch = microbatch_id
def forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data=False,
checkpoint_activations_microbatch=None,
is_first_microbatch=False,
current_microbatch=None,
):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor."""
if config.timers is not None:
config.timers('forward-compute', log_level=2).start()
if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
model.set_is_first_microbatch()
if current_microbatch is not None:
set_current_microbatch(model, current_microbatch)
unwrap_output_tensor = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_output_tensor = True
set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)
if config.enable_autocast:
context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
else:
context_manager = contextlib.nullcontext()
with context_manager:
if checkpoint_activations_microbatch is None:
output_tensor, loss_func = forward_step_func(data_iterator, model)
else:
output_tensor, loss_func = forward_step_func(
data_iterator, model, checkpoint_activations_microbatch
)
num_tokens = torch.tensor(0, dtype=torch.int)
if parallel_state.is_pipeline_last_stage():
if not collect_non_loss_data:
outputs = loss_func(output_tensor)
if len(outputs) == 3:
output_tensor, num_tokens, loss_reduced = outputs
if not config.calculate_per_token_loss:
output_tensor /= num_tokens
output_tensor /= num_microbatches
else:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert len(outputs) == 2
output_tensor, loss_reduced = outputs
output_tensor /= num_microbatches
forward_data_store.append(loss_reduced)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
if config.timers is not None:
config.timers('forward-compute').stop()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.tensor(1.0)
)
# Set the loss scale
MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
# If T5 model (or other model with encoder and decoder)
# and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type = get_model_type(model)
if (
parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
return [output_tensor, input_tensor[-1]], num_tokens
if unwrap_output_tensor:
return output_tensor, num_tokens
return [output_tensor], num_tokens
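# Illustrative note, not part of the original source: on the last stage the
# reduced loss is divided by num_microbatches (and, when the loss function also
# returns a token count and calculate_per_token_loss is off, by num_tokens), so
# with e.g. 8 microbatches each backward pass contributes grad(loss_i) / 8 and
# the accumulated gradient matches the gradient of the mean loss over the
# global batch.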
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
if config.timers is not None:
config.timers('backward-compute', log_level=2).start()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if output_tensor_grad[0] is None and config.grad_scale_func is not None:
output_tensor[0] = config.grad_scale_func(output_tensor[0])
if config.deallocate_pipeline_outputs:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = []
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
if config.timers is not None:
config.timers('backward-compute').stop()
return input_tensor_grad
def check_first_val_step(first_val_step, forward_only, cond):
if (first_val_step is not None) and forward_only:
return first_val_step and cond
else:
return cond
def forward_backward_no_pipelining(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int, # unused
micro_batch_size: int, # unused
decoder_seq_length: int = None, # unused
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: bool = None,
):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
Returns dictionary with losses.
See get_forward_backward_func() for argument details
"""
if isinstance(model, list):
assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking"
model = model[0]
if isinstance(data_iterator, list):
assert (
len(data_iterator) == 1
), "non-pipeline-parallel schedule does not support model chunking"
data_iterator = data_iterator[0]
config = get_model_config(model)
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
model_type = get_model_type(model)
forward_data_store = []
input_tensor, output_tensor_grad = None, None
total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
with no_sync_func():
for i in range(num_microbatches - 1):
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=i,
)
total_num_tokens += num_tokens.item()
if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
# Run the computation for the last microbatch outside the context handler
# (we want to synchronize gradients here).
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
is_first_microbatch=check_first_val_step(
first_val_step, forward_only, num_microbatches == 1
),
current_microbatch=num_microbatches - 1,
)
total_num_tokens += num_tokens.item()
if not forward_only:
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)
if config.finalize_model_grads_func is not None and not forward_only:
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism and layernorm all-reduce for sequence parallelism).
config.finalize_model_grads_func(
[model], total_num_tokens if config.calculate_per_token_loss else None
)
if config.timers is not None:
config.timers('forward-backward').stop()
return forward_data_store
def forward_backward_pipelining_with_interleaving(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: bool = None,
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking"
assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking"
assert isinstance(
data_iterator, list
), "interleaved pipeline parallelism expected each model chunk to have a data iterator"
config = get_model_config(model[0])
if config.overlap_p2p_comm and config.batch_p2p_comm:
raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
# Disable async grad reductions
no_sync_func = config.no_sync_func
if isinstance(no_sync_func, list):
def multi_no_sync():
stack = contextlib.ExitStack()
for model_chunk_no_sync_func in config.no_sync_func:
stack.enter_context(model_chunk_no_sync_func())
return stack
no_sync_func = multi_no_sync
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
config.grad_sync_func = [config.grad_sync_func for _ in model]
if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
config.param_sync_func = [config.param_sync_func for _ in model]
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()
# Model chunk IDs with synchronized grads
synchronized_model_chunks = set()
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
forward_data_store = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
if num_microbatches % pipeline_parallel_size != 0:
msg = f'number of microbatches ({num_microbatches}) is not divisible by '
msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) '
msg += 'when using interleaved schedule'
raise RuntimeError(msg)
model_type = get_model_type(model[0])
if model_type == ModelType.encoder_and_decoder:
raise RuntimeError("Interleaving is not supported with an encoder and decoder model.")
if decoder_seq_length is not None and decoder_seq_length != seq_length:
raise RuntimeError(
"Interleaving is not supported with a different decoder sequence length."
)
tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
if config.sequence_parallel:
tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
# Compute number of warmup and remaining microbatches.
num_model_chunks = len(model)
total_num_microbatches = num_microbatches * num_model_chunks
all_warmup_microbatches = False
if forward_only:
num_warmup_microbatches = total_num_microbatches
else:
# Run all forward passes and then all backward passes if the number of
# microbatches equals the number of pipeline stages.
# Otherwise, run (num_model_chunks - 1) * pipeline_parallel_size warmup
# microbatches on all workers, plus additional warmup microbatches
# depending on the stage ID (earlier stages run more forward passes;
# later stages can start 1F1B sooner).
if num_microbatches == pipeline_parallel_size:
num_warmup_microbatches = total_num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches)
num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches
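# Worked example (illustrative, not in the original source): with
# pipeline_parallel_size = 4, num_model_chunks = 2 and num_microbatches = 8,
# total_num_microbatches = 16. Rank 0 gets (4 - 0 - 1) * 2 + (2 - 1) * 4 = 10
# warmup microbatches and rank 3 gets (4 - 3 - 1) * 2 + 4 = 4, leaving 6 and 12
# steady-state (1F1B) microbatches respectively.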
# Checkpoint the activations of partial Transformer layers in a number of micro-batches
# within the maximum outstanding micro-batch backpropagations.
# Micro-batches with ids less than 'num_microbatches_with_partial_activation_checkpoints'
# checkpoint partial Transformer layers (or skip checkpointing), and the rest of the
# micro-batches within the window checkpoint all Transformer layers. The window of
# micro-batches is set by the maximum outstanding backpropagations and becomes
# smaller at later pipeline stages.
# Please refer to Appendix C in https://arxiv.org/pdf/2205.05198.pdf
max_outstanding_backprops = None
if config.num_microbatches_with_partial_activation_checkpoints is not None:
max_outstanding_backprops = num_warmup_microbatches + 1
# Synchronize params for first two model chunks
if config.param_sync_func is not None:
config.param_sync_func[0](model[0].parameters())
config.param_sync_func[1](model[1].parameters())
def get_model_chunk_id(microbatch_id, forward):
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
if not forward:
model_chunk_id = num_model_chunks - model_chunk_id - 1
return model_chunk_id
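# Worked example (illustrative, not in the original source): with
# pipeline_parallel_size = 4 and num_model_chunks = 2, forward microbatch ids
# 0-3 map to chunk 0, ids 4-7 to chunk 1, ids 8-11 back to chunk 0, and so on;
# in the backward direction the chunk order is reversed, so ids 0-3 map to
# chunk 1 and ids 4-7 to chunk 0 within each group of 8.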
def get_microbatch_id_in_model_chunk(iteration_id, forward):
"""Helper method to get the microbatch_id within model chunk given the iteration number."""
assert forward
iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks)
microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + (
iteration_id % pipeline_parallel_size
)
return microbatch_id_in_model_chunk
def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the first for a model chunk."""
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = total_num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == 0:
return microbatch_id_in_group % pipeline_parallel_size == 0
else:
return False
def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool:
"""Check if an iteration is the last for a model chunk."""
microbatch_group_size = pipeline_parallel_size * num_model_chunks
num_microbatch_groups = total_num_microbatches // microbatch_group_size
microbatch_group_id = microbatch_id // microbatch_group_size
microbatch_id_in_group = microbatch_id % microbatch_group_size
if microbatch_group_id == num_microbatch_groups - 1:
return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1
else:
return False
def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch param synchronization for next model chunk
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if config.param_sync_func is not None:
param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank
if (
param_sync_microbatch_id < total_num_microbatches
and is_first_microbatch_for_model_chunk(param_sync_microbatch_id)
):
param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
if 1 < param_sync_chunk_id < num_model_chunks:
config.param_sync_func[param_sync_chunk_id](
model[param_sync_chunk_id].parameters()
)
# forward step
if parallel_state.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(
first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id),
),
current_microbatch=current_microbatch,
)
output_tensors[model_chunk_id].append(output_tensor)
nonlocal total_num_tokens
total_num_tokens += num_tokens.item()
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
return output_tensor
def backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
# launch grad synchronization (default)
if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id):
enable_grad_sync()
synchronized_model_chunks.add(model_chunk_id)
if parallel_state.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
# launch grad synchronization (custom grad sync)
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if config.grad_sync_func is not None:
grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank
if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk(
grad_sync_microbatch_id
):
grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
enable_grad_sync()
config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
synchronized_model_chunks.add(grad_sync_chunk_id)
disable_grad_sync()
return input_tensor_grad
# Run warmup forward passes.
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config))
fwd_wait_handles = None
bwd_wait_handles = None
for k in range(num_warmup_microbatches):
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
cur_model_chunk_id = get_model_chunk_id(k, forward=True)
# Decide whether to checkpoint all layers' activations for the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
k % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True)
output_tensor = forward_step_helper(
k, current_microbatch, checkpoint_activations_microbatch
)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if next_forward_model_chunk_id == 0:
recv_prev = False
if k == (total_num_microbatches - 1):
recv_prev = False
# Don't send tensor downstream if on last stage.
if parallel_state.is_pipeline_last_stage():
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if not config.overlap_p2p_comm:
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config
)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
else:
input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
if (
k == (num_warmup_microbatches - 1)
and not forward_only
and not all_warmup_microbatches
):
input_tensor_grad = None
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
(
output_tensor_grad,
bwd_wait_handles,
) = p2p_communication.send_backward_recv_backward(
input_tensor_grad,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
# Decide whether to checkpoint all layers' activations for the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
forward_k % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True)
if config.overlap_p2p_comm:
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
output_tensor = forward_step_helper(
forward_k, current_microbatch, checkpoint_activations_microbatch
)
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
# The last virtual stage has no activation tensor to send
if parallel_state.is_pipeline_last_stage():
output_tensor = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True
)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Send activation tensor to the next stage and receive activation tensor from the
# previous stage
input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward(
output_tensor,
recv_prev=recv_prev,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
# assert fwd_wait_handles is not None
if bwd_wait_handles is not None:
for req in bwd_wait_handles:
req.wait()
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
# The first virtual stage has no activation gradient tensor to send
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if the current virtual stage has an activation gradient tensor to receive
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False
)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward(
input_tensor_grad,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
overlap_p2p_comm=True,
)
else: # no p2p overlap
output_tensor = forward_step_helper(
forward_k, current_microbatch, checkpoint_activations_microbatch
)
# Backward pass.
backward_k = k
input_tensor_grad = backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id)
if parallel_state.is_pipeline_last_stage():
output_tensor = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id)
if parallel_state.is_pipeline_first_stage():
input_tensor_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
recv_prev = True
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
# First stage is ahead of last stage by (pipeline_parallel_size - 1).
next_forward_model_chunk_id = get_model_chunk_id(
forward_k - (pipeline_parallel_size - 1), forward=True
)
if next_forward_model_chunk_id == (num_model_chunks - 1):
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
# Last stage is ahead of first stage by (pipeline_parallel_size - 1).
next_backward_model_chunk_id = get_model_chunk_id(
backward_k - (pipeline_parallel_size - 1), forward=False
)
if next_backward_model_chunk_id == 0:
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
# Communicate tensors.
(
input_tensor,
output_tensor_grad,
) = p2p_communication.send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
config=config,
)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if config.overlap_p2p_comm and bwd_wait_handles is not None:
for wait_handle in bwd_wait_handles:
wait_handle.wait()
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(
p2p_communication.recv_backward(tensor_shape, config=config)
)
for k in range(num_microbatches_remaining, total_num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if next_backward_model_chunk_id == (num_model_chunks - 1):
recv_next = False
if k == (total_num_microbatches - 1):
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config
)
)
# Launch any remaining grad reductions.
enable_grad_sync()
if config.grad_sync_func is not None:
for model_chunk_id in range(num_model_chunks):
if model_chunk_id not in synchronized_model_chunks:
config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
synchronized_model_chunks.add(model_chunk_id)
if config.finalize_model_grads_func is not None and not forward_only:
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism, layernorm all-reduce for sequence parallelism, and
# embedding all-reduce for pipeline parallelism).
config.finalize_model_grads_func(
model, total_num_tokens if config.calculate_per_token_loss else None
)
if config.timers is not None:
config.timers('forward-backward').stop()
return forward_data_store
def get_tensor_shapes(
*,
rank: int,
model_type: ModelType,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int,
config,
):
# Determine the right tensor sizes (based on the position of the rank with
# respect to the split rank) and the model size.
# Send two tensors if model is T5 and rank is in decoder stage:
# first tensor is decoder (pre-transpose),
# second tensor is encoder (post-transpose).
# If model is T5 and rank is at the boundary:
# send one tensor (post-transpose from encoder).
# Otherwise, send one tensor (pre-transpose).
tensor_shapes = []
seq_length = seq_length // parallel_state.get_context_parallel_world_size()
if model_type == ModelType.encoder_and_decoder:
decoder_seq_length = decoder_seq_length // parallel_state.get_context_parallel_world_size()
if config.sequence_parallel:
seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size()
if model_type == ModelType.encoder_and_decoder:
decoder_seq_length = (
decoder_seq_length // parallel_state.get_tensor_model_parallel_world_size()
)
if model_type == ModelType.encoder_and_decoder:
if parallel_state.is_pipeline_stage_before_split(rank):
tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
else:
tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
else:
tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
return tensor_shapes
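# Worked example (illustrative, not in the original source): for a single-stack
# model with seq_length = 4096, micro_batch_size = 1, hidden_size = 8192, a
# context-parallel size of 2 and sequence parallelism over a tensor-parallel
# size of 8, the single communicated shape is (4096 // 2 // 8, 1, 8192),
# i.e. (256, 1, 8192).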
def recv_forward(tensor_shapes, config):
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape, config))
return input_tensors
def recv_backward(tensor_shapes, config):
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape, config))
return output_tensor_grads
def send_forward(output_tensors, tensor_shapes, config):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, config)
def send_backward(input_tensor_grads, tensor_shapes, config):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, config)
def send_forward_recv_backward(output_tensors, tensor_shapes, config):
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(
output_tensor, tensor_shape, config
)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(
input_tensor_grad, tensor_shape, config
)
input_tensors.append(input_tensor)
return input_tensors
def forward_backward_pipelining_without_interleaving(
*,
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
decoder_seq_length: int = None,
forward_only: bool = False,
collect_non_loss_data: bool = False,
first_val_step: bool = None,
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
Returns dictionary with losses if the last stage, empty dict otherwise."""
if isinstance(model, list):
assert (
len(model) == 1
), "non-interleaved pipeline parallelism does not support model chunking"
model = model[0]
if isinstance(data_iterator, list):
assert (
len(data_iterator) == 1
), "non-pipeline-parallel schedule does not support model chunking"
data_iterator = data_iterator[0]
config = get_model_config(model)
if config.overlap_p2p_comm:
raise ValueError(
"Non-interleaved pipeline parallelism does not support overlapping p2p communication"
)
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
# Disable async grad reductions
no_sync_func = config.no_sync_func
if no_sync_func is None:
no_sync_func = contextlib.nullcontext
no_sync_context = None
def disable_grad_sync():
"""Disable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is None:
no_sync_context = no_sync_func()
no_sync_context.__enter__()
def enable_grad_sync():
"""Enable asynchronous grad reductions"""
nonlocal no_sync_context
if no_sync_context is not None:
no_sync_context.__exit__(None, None, None)
no_sync_context = None
disable_grad_sync()
# Compute number of warmup microbatches.
num_warmup_microbatches = (
parallel_state.get_pipeline_model_parallel_world_size()
- parallel_state.get_pipeline_model_parallel_rank()
- 1
)
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
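# Illustrative example (assuming pipeline_parallel_size=4, rank=1, num_microbatches=8):
# num_warmup_microbatches = 4 - 1 - 1 = 2 and num_microbatches_remaining = 6.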
# Checkpoint the activations of partial Transformer layers in a number of micro-batches
# within the maximum outstanding micro-batch backpropagations.
# Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
# checkpoint partial Transformer layers (or skip checkpointing) and
# the rest of micro-batches within a window of micro-batches checkpoint
# all Transformer layers. The window of micro-batches is set by the maximum
# outstanding backpropagations and becomes smaller at later pipeline stages.
# Please refer to Appendix C in https://arxiv.org/pdf/2205.05198.pdf
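# Illustrative example (assuming num_warmup_microbatches=2, so max_outstanding_backprops=3,
# and num_microbatches_with_partial_activation_checkpoints=1): within each window of 3
# micro-batches, the first one (i % 3 == 0) uses partial checkpointing and the other two
# checkpoint all Transformer layers.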
max_outstanding_backprops = None
if config.num_microbatches_with_partial_activation_checkpoints is not None:
max_outstanding_backprops = num_warmup_microbatches + 1
model_type = get_model_type(model)
rank = parallel_state.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(
rank=rank - 1,
model_type=model_type,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
decoder_seq_length=decoder_seq_length,
config=config,
)
send_tensor_shapes = get_tensor_shapes(
rank=rank,
model_type=model_type,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
decoder_seq_length=decoder_seq_length,
config=config,
)
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
if not forward_only:
input_tensors = []
output_tensors = []
forward_data_store = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
i % max_outstanding_backprops
>= config.num_microbatches_with_partial_activation_checkpoints
)
else:
checkpoint_activations_microbatch = None
input_tensor = recv_forward(recv_tensor_shapes, config)
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(first_val_step, forward_only, i == 0),
current_microbatch=i,
)
send_forward(output_tensor, send_tensor_shapes, config)
total_num_tokens += num_tokens.item()
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = recv_forward(recv_tensor_shapes, config)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = i == (num_microbatches_remaining - 1)
# Decide to checkpoint all layers' activations of the current micro-batch
if max_outstanding_backprops is not None:
checkpoint_activations_microbatch = (
(i + num_warmup_microbatches) % max_outstanding_backprops
) >= config.num_microbatches_with_partial_activation_checkpoints
else:
checkpoint_activations_microbatch = None
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
collect_non_loss_data,
checkpoint_activations_microbatch,
check_first_val_step(
first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0)
),
current_microbatch=i + num_warmup_microbatches,
)
total_num_tokens += num_tokens.item()
if forward_only:
send_forward(output_tensor, send_tensor_shapes, config)
if not last_iteration:
input_tensor = recv_forward(recv_tensor_shapes, config)
else:
output_tensor_grad = send_forward_recv_backward(
output_tensor, send_tensor_shapes, config
)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
# Enable grad sync for the last microbatch in the batch if the full
# backward pass completes in the 1F1B stage.
if num_warmup_microbatches == 0 and last_iteration:
if config.grad_sync_func is None or rank == 0:
enable_grad_sync()
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
if last_iteration:
input_tensor = None
send_backward(input_tensor_grad, recv_tensor_shapes, config)
else:
input_tensor = send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, config
)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
# async grad reduction in first pipeline stage. Other
# pipeline stages do grad reduction during pipeline
# bubble.
if i == num_warmup_microbatches - 1:
if config.grad_sync_func is None or rank == 0:
enable_grad_sync()
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = recv_backward(send_tensor_shapes, config)
input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)
send_backward(input_tensor_grad, recv_tensor_shapes, config)
# Launch any remaining grad reductions.
if no_sync_context is not None:
enable_grad_sync()
if config.grad_sync_func is not None:
config.grad_sync_func(model.parameters())
if config.finalize_model_grads_func is not None and not forward_only:
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism, layernorm all-reduce for sequence parallelism, and
# embedding all-reduce for pipeline parallelism).
config.finalize_model_grads_func(
[model], total_num_tokens if config.calculate_per_token_loss else None
)
if config.timers is not None:
config.timers('forward-backward').stop()
return forward_data_store
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from functools import partial
from typing import Union
from torch import Tensor, nn
from megatron.core import parallel_state
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols
from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.custom_layers.transformer_engine import TENorm
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
def create_mamba_block(
config, mamba_layer_spec, residual_in_fp32=False, layer_idx=None,
):
block = build_module(
mamba_layer_spec, config, residual_in_fp32=residual_in_fp32, layer_idx=layer_idx,
)
block.layer_idx = layer_idx
return block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
module,
n_layer,
initializer_range=0.02, # Now only used for embedding layer.
rescale_prenorm_residual=True,
n_residuals_per_layer=1, # Change to 2 if we have MLP
):
with get_cuda_rng_tracker().fork():
if isinstance(module, nn.Linear):
if not getattr(module.weight, "_no_reinit", False):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
for name, p in module.named_parameters():
if name in ["in_proj.weight", "x_proj.weight", "conv1d.weight", "out_proj.weight"]:
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization
nn.init.normal_(
p,
mean=0.0,
std=initializer_range / math.sqrt(n_residuals_per_layer * n_layer),
)
@dataclass
class MambaStackSubmodules:
mamba_layer: Union[ModuleSpec, type] = IdentityOp
attention_layer: Union[ModuleSpec, type] = IdentityOp
mlp_layer: Union[ModuleSpec, type] = IdentityOp
class MambaStack(MegatronModule):
def __init__(
self,
config: TransformerConfig,
submodules: MambaStackSubmodules,
residual_in_fp32=False,
pre_process: bool = True,
hybrid_attention_ratio: float = 0.0,
hybrid_mlp_ratio: float = 0.0,
hybrid_override_pattern: str = None,
post_layer_norm: bool = True,
post_process: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__(config=config)
self.residual_in_fp32 = residual_in_fp32
self.pre_process = pre_process
self.post_layer_norm = post_layer_norm
self.post_process = post_process
# Required for pipeline parallel schedules
self.input_tensor = None
self.hybrid_attention_ratio = hybrid_attention_ratio
self.hybrid_mlp_ratio = hybrid_mlp_ratio
self.hybrid_override_pattern = hybrid_override_pattern
layer_type_list = allocate_layers(
self.config.num_layers,
self.hybrid_attention_ratio,
self.hybrid_mlp_ratio,
self.hybrid_override_pattern,
)
pp_layer_offset = 0
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(
layer_type_list
)
self.layers = nn.ModuleList()
for i, layer_type in enumerate(layer_type_list):
if layer_type == LayerSymbols.MAMBA:
layer_idx = i + pp_layer_offset
block = create_mamba_block(
self.config,
submodules.mamba_layer,
residual_in_fp32=residual_in_fp32,
layer_idx=layer_idx,
)
elif layer_type == LayerSymbols.ATTENTION:
# Wondering if layer_number should be i+1. See TransformerBlock
# and TransformerLayer::sharded_state_dict
# Also, transformer layers apply their own pp_layer_offset
block = build_module(submodules.attention_layer, config=self.config, layer_number=i)
elif layer_type == LayerSymbols.MLP:
# Wondering if layer_number should be i+1. See TransformerBlock
# and TransformerLayer::sharded_state_dict
# Also, transformer layers apply their own pp_layer_offset
block = build_module(submodules.mlp_layer, config=self.config, layer_number=i)
else:
assert False, "unexpected layer_type"
self.layers.append(block)
# Required for activation recomputation
self.num_layers_per_pipeline_rank = len(self.layers)
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_norm = TENorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.apply(partial(_init_weights, n_layer=self.config.num_layers,))
def _select_layers_for_pipeline_parallel(self, layer_type_list):
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
num_layers_per_pipeline_rank = (
self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
assert parallel_state.get_virtual_pipeline_model_parallel_world_size() is None, (
"The Mamba hybrid model does not currently support "
"virtual/interleaved pipeline parallelism"
)
offset = pipeline_rank * num_layers_per_pipeline_rank
selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank]
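# Illustrative example (assuming 48 total layers and a pipeline-parallel size of 4):
# rank 1 gets offset 12 and the layer types for layers 12..23.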
return offset, selected_list
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
return {
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
for i, layer in enumerate(self.layers)
}
def set_input_tensor(self, input_tensor: Tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
inference_params=None,
rotary_pos_emb: Tensor = None,
):
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
if inference_params:
# NOTE(bnorick): match InferenceParams attributes for mamba_ssm.utils.generation.InferenceParams,
# this hack supports eval
inference_params.max_seqlen = inference_params.max_sequence_length
inference_params.seqlen_offset = inference_params.sequence_len_offset
for layer in self.layers:
hidden_states = layer(
hidden_states,
attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
# The attention layer (currently a simplified transformer layer)
# outputs a tuple of (hidden_states, context). Context is intended
# for cross-attention, and is not needed in our model.
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
# Final layer norm.
if self.post_process and self.post_layer_norm:
hidden_states = self.final_norm(hidden_states)
# Ensure that the tensor passed between pipeline parallel stages is
# viewless. See related notes in TransformerBlock and TransformerLayer
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
return output
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
if __name__ != "__main__":
from megatron.core.utils import log_single_rank
else:
from typing import Any
def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
print(*args[1:], **kwargs)
logger = logging.getLogger(__name__)
class Symbols:
MAMBA = 'M'
ATTENTION = '*'
MLP = '-'
VALID = {MAMBA, ATTENTION, MLP}
def _allocate_auto(
total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float
) -> list:
# First, allocate attention (evenly spaced, starting and ending with mamba)
attention_layers_count: int = round(total_layers_count * target_attention_ratio)
mamba_layers_count: int = total_layers_count - attention_layers_count
mamba_sections_count: int = attention_layers_count + 1
mamba_section_length: float = mamba_layers_count / mamba_sections_count
layer_type_list = [Symbols.MAMBA] * total_layers_count
x: float = mamba_section_length
for l in range(total_layers_count):
if x < 0.5:
layer_type_list[l] = Symbols.ATTENTION
x += mamba_section_length
else:
x -= 1
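# Illustrative trace (assuming total_layers_count=10 and target_attention_ratio=0.2):
# this pass places attention at layers 3 and 6, yielding MMM*MM*MMM before any MLP
# substitution below.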
# Next, allocate mlp
# (evenly distributed, but right-justified, not replacing attention)
mlp_layers_count: int = round(total_layers_count * target_mlp_ratio)
if mlp_layers_count > 0:
mamba_layers_count -= mlp_layers_count
mamba_to_mlp_ratio: float = mamba_layers_count / mlp_layers_count
x: float = mamba_to_mlp_ratio
for l in range(total_layers_count):
if layer_type_list[l] == Symbols.MAMBA:
if x < 0.5:
layer_type_list[l] = Symbols.MLP
x += mamba_to_mlp_ratio
else:
x -= 1
return layer_type_list
def _allocate_override(total_layers_count: int, override_pattern: str) -> list:
layer_type_list = list(override_pattern)
override_pattern_length = len(layer_type_list)
if override_pattern_length != total_layers_count:
raise ValueError(
"The hybrid override pattern is the wrong "
f"length: got {override_pattern_length}, expected "
f"{total_layers_count}"
)
for l in layer_type_list:
if l not in Symbols.VALID:
raise ValueError(f"In hybrid override pattern, '{l}' is not " f"one of {Symbols.VALID}")
return layer_type_list
def _layer_counts_match(a: list, b: list) -> bool:
for s in Symbols.VALID:
if a.count(s) != b.count(s):
return False
return True
def allocate_layers(
total_layers_count: int,
target_attention_ratio: float,
target_mlp_ratio: float,
override_pattern: str = None,
) -> list:
assert total_layers_count > 0
assert target_attention_ratio >= 0.0 and target_attention_ratio <= 1.0
assert target_mlp_ratio >= 0.0 and target_mlp_ratio <= 1.0
assert target_attention_ratio + target_mlp_ratio <= 1.0
# Note: target_mamba_ratio = 1.0 - target_attention_ratio - target_mlp_ratio
layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio)
if override_pattern is not None:
layer_type_list_override = _allocate_override(total_layers_count, override_pattern)
log_single_rank(logger, logging.INFO, "Using hybrid override pattern")
if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0) and not _layer_counts_match(
layer_type_list_override, layer_type_list
):
raise ValueError(
"The number of each type of layer in the override "
"pattern must match the number in the overridden "
"pattern."
)
if layer_type_list_override == layer_type_list:
log_single_rank(
logger, logging.INFO, "The override pattern matches the overridden pattern"
)
else:
log_single_rank(logger, logging.INFO, "Warning: overriding pattern A with pattern B")
log_single_rank(logger, logging.INFO, f"A: {''.join(layer_type_list)}")
log_single_rank(logger, logging.INFO, f"B: {''.join(layer_type_list_override)}")
layer_type_list = layer_type_list_override
if target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or override_pattern is not None:
actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION)
actual_attention_ratio = actual_attention_layers_count / total_layers_count
actual_mlp_layers_count = layer_type_list.count(Symbols.MLP)
actual_mlp_ratio = actual_mlp_layers_count / total_layers_count
allocation_string = ''.join(layer_type_list)
log_single_rank(
logger,
logging.INFO,
f"Hybrid allocation ({Symbols.MAMBA} is mamba, "
f"{Symbols.ATTENTION} is attention, "
f"{Symbols.MLP} is mlp):",
)
log_single_rank(logger, logging.INFO, allocation_string)
log_single_rank(
logger,
logging.INFO,
f"{actual_attention_layers_count} attention layers in "
f"{total_layers_count} total layers.",
)
log_single_rank(
logger,
logging.INFO,
f"Target attention ratio: {target_attention_ratio:.2f}. "
f"Actual attention ratio: {actual_attention_ratio:.2f}.",
)
log_single_rank(
logger,
logging.INFO,
f"{actual_mlp_layers_count} mlp layers in " f"{total_layers_count} total layers.",
)
log_single_rank(
logger,
logging.INFO,
f"Target mlp ratio: {target_mlp_ratio:.2f}. "
f"Actual mlp ratio: {actual_mlp_ratio:.2f}.",
)
return layer_type_list
if __name__ == "__main__":
test_cases = [
# (10, 0.2, 0.0),
# (48, 0.0, 0.0), # will not print anything
# (48, 0.1, 0.0),
# (48, 0.3, 0.0),
# (48, 0.5, 0.0),
# (48, 0.6, 0.0),
# (48, 0.7, 0.0),
# (10, 0.0, 0.1),
# (10, 0.0, 0.3),
# (10, 0.0, 0.5),
# (10, 0.1, 0.1),
# (10, 0.2, 0.2),
# (10, 0.3, 0.3),
# (10, 0.5, 0.5),
# (48, 0.2, 0.3),
# (48, 0.5, 0.2),
# (48, 0.5, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.25, 0.25, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.25, 0.25, "MM-*MM-*MM*-MM*-MM*-MM*-M*M-M*M-M*M-M*M-*MM-*MM-"),
# (48, 0.0, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.2, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.0, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.5, 0.5),
# (10, 0.3, 0.2, "MMM*-*M*M-"),
# (10, 0.3, 0.2, "MM*M-*M*M-"),
(9, 0.0, 0.0, "M*-M*-M*-"),
(9, 0.0, 0.0, "MMMMMMMMM"),
]
for t in test_cases:
print("")
allocate_layers(*t)
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Union
import torch
from torch import Tensor
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
@dataclass
class MambaLayerSubmodules:
norm: Union[ModuleSpec, type] = IdentityOp
mixer: Union[ModuleSpec, type] = IdentityOp
class MambaLayer(MegatronModule):
def __init__(
self,
config: TransformerConfig,
submodules: MambaLayerSubmodules,
layer_idx=None,
residual_in_fp32=False,
):
"""
Top level Mamba Layer
"""
super().__init__(config)
self.config = config
self.residual_in_fp32 = residual_in_fp32
self.mixer = build_module(
submodules.mixer, self.config, self.config.hidden_size, layer_idx=layer_idx,
)
self.norm = build_module(submodules.norm, self.config, self.config.hidden_size)
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor, # Not used in MambaLayer
inference_params=None,
rotary_pos_emb: Tensor = None, # Not used in MambaLayer
):
residual = hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
return hidden_states + residual
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.parallel_state import get_tensor_model_parallel_world_size
from megatron.core.tensor_parallel import (
ColumnParallelLinear,
RowParallelLinear,
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
get_cuda_rng_tracker,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn = None
causal_conv1d_update = None
try:
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
except ImportError:
raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported")
try:
from einops import rearrange, repeat
except ImportError:
raise ImportError("einops is required by the Mamba model but cannot be imported")
class Mamba(MegatronModule):
def __init__(
self,
config: TransformerConfig,
d_model,
d_state=128,
d_conv=4,
conv_init=None,
expand=2,
headdim=64,
ngroups=8,
A_init_range=(1, 16),
D_has_hdim=False,
rmsnorm=True,
norm_before_gate=False,
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
bias=False,
conv_bias=True,
# Fused kernel and sharding options
chunk_size=128,
use_fast_path=True,
layer_idx=None,
):
super().__init__(config)
self.config = config
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.conv_init = conv_init
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.headdim = headdim
self.ngroups = ngroups
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
self.D_has_hdim = D_has_hdim
self.rmsnorm = rmsnorm
self.norm_before_gate = norm_before_gate
self.chunk_size = chunk_size
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
assert self.d_inner % self.tensor_model_parallel_size == 0
assert self.ngroups % self.tensor_model_parallel_size == 0
assert self.nheads % self.tensor_model_parallel_size == 0
assert not bias
self.d_inner_local = self.d_inner // self.tensor_model_parallel_size
self.ngroups_local = self.ngroups // self.tensor_model_parallel_size
self.nheads_local = self.nheads // self.tensor_model_parallel_size
assert self.d_inner_local % self.ngroups_local == 0
# Assume sequence parallelism: input is already partitioned along the
# sequence dimension
self.in_proj = ColumnParallelLinear(
self.d_model,
self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=bias,
)
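# Note: the in_proj output dimension packs the three pieces that the forward pass later
# splits apart: z (d_inner), xBC (d_inner + 2 * ngroups * d_state), and dt (nheads).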
conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state
with get_cuda_rng_tracker().fork():
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
setattr(self.conv1d.weight, 'tensor_model_parallel', True)
setattr(self.conv1d.bias, 'tensor_model_parallel', True)
if self.conv_init is not None:
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
self.activation = "silu"
self.act = nn.SiLU()
with get_cuda_rng_tracker().fork():
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(
self.nheads_local, device=torch.cuda.current_device(), dtype=config.params_dtype
)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
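# Derivation sketch: softplus(y) = log(1 + exp(y)) = dt implies
# y = log(exp(dt) - 1) = dt + log(1 - exp(-dt)) = dt + log(-expm1(-dt)),
# which is the expression above (expm1 keeps the computation stable for small dt).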
with torch.no_grad():
self.dt_bias = nn.Parameter(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_bias._no_reinit = True
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(
self.nheads_local, dtype=torch.float32, device=torch.cuda.current_device()
).uniform_(*A_init_range)
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
setattr(self.A_log, 'tensor_model_parallel', True)
# D "skip" parameter
self.D = nn.Parameter(
torch.ones(
self.d_inner_local if self.D_has_hdim else self.nheads_local,
device=torch.cuda.current_device(),
)
) # Keep in fp32
self.D._no_weight_decay = True
setattr(self.D, 'tensor_model_parallel', True)
if self.rmsnorm:
assert RMSNormGated is not None
self.norm = RMSNormGated(
self.d_inner_local,
eps=1e-5,
group_size=self.d_inner_local // self.ngroups_local,
norm_before_gate=False,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
# Assume sequence parallelism: input is partitioned along d_inner and
# output is partitioned along the sequence dimension
self.out_proj = RowParallelLinear(
self.d_inner,
self.d_model,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=bias,
input_is_parallel=True,
skip_bias_add=False,
)
def forward(self, hidden_states, inference_params=None):
"""
hidden_states: (nL, B, D) / (L B D)
Returns: same shape as hidden_states
"""
_, batch, dim = hidden_states.shape
conv_state, ssm_state = None, None
if inference_params is not None:
assert not self.config.sequence_parallel
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
# (nheads_local)
A = -torch.exp(self.A_log.float())
# pl b d -> l b p(2d)
# TODO move transpose to GEMM
if self.config.sequence_parallel:
# gather data along the sequence dimension
hidden_states = gather_from_sequence_parallel_region(hidden_states)
else:
hidden_states = copy_to_tensor_model_parallel_region(hidden_states)
xz = hidden_states @ self.in_proj.weight.t()
z, xBC, dt = torch.split(
xz,
[
self.d_inner_local,
self.d_inner_local + 2 * self.ngroups_local * self.d_state,
self.nheads_local,
],
dim=-1,
)
# transpose: l b pd --> b pd l
xBC = rearrange(xBC, "l b d -> b d l")
xBC = xBC.contiguous()
# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(xBC, (self.d_conv - xBC.shape[-1], 0))) # Update state (B D W)
seqlen = xBC.size(2)
if causal_conv1d_fn is None:
xBC = self.act(self.conv1d(xBC)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
xBC = causal_conv1d_fn(
x=xBC,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
# transpose b pd l --> l b pd
xBC = rearrange(xBC, "b d l -> l b d")
xBC = xBC.contiguous()
x, B, C = torch.split(
xBC,
[
self.d_inner_local,
self.ngroups_local * self.d_state,
self.ngroups_local * self.d_state,
],
dim=-1,
)
# TODO Vijay: fuse most of the transposes with the GEMMS
x = rearrange(x, "l b (h p) -> b l h p", p=self.headdim).contiguous()
dt = rearrange(dt, "l b d -> b l d").contiguous()
B = rearrange(B, "l b (g n) -> b l g n", n=self.d_state).contiguous()
C = rearrange(C, "l b (g n) -> b l g n", n=self.d_state).contiguous()
z = rearrange(z, "l b (h p) -> b l h p", p=self.headdim).contiguous()
y = mamba_chunk_scan_combined(
x,
dt,
A,
B,
C,
self.chunk_size,
D=rearrange(self.D.float(), "(h p) -> h p", p=self.headdim)
if self.D_has_hdim
else self.D,
z=z if not self.rmsnorm else None,
dt_bias=self.dt_bias.float(),
dt_softplus=True,
return_final_states=ssm_state is not None,
)
if ssm_state is not None:
y, last_state = y
ssm_state.copy_(last_state)
if self.rmsnorm:
y = rearrange(y, "b l h p -> b l (h p)").contiguous()
z = rearrange(z, "b l h p -> b l (h p)").contiguous()
y = self.norm(y, z)
y = rearrange(y, "b l d -> l b d").contiguous()
else:
y = rearrange(y, "b l h p -> l b (h p)").contiguous()
# l b pd --> pl b d
out_full = y @ self.out_proj.weight.t()
if self.config.sequence_parallel:
out = reduce_scatter_to_sequence_parallel_region(out_full)
else:
out = reduce_from_tensor_model_parallel_region(out_full)
return out
def step(self, hidden_states, conv_state, ssm_state):
# assert self.ngroups_local == 1, "Only support ngroups=1 for inference for now"
dtype = hidden_states.dtype
assert hidden_states.shape[0] == 1, "Only support decoding with 1 token at a time for now"
# l b d --> b d
hidden_states = hidden_states.squeeze(0)
# b d_model --> b p(2d)
xz = hidden_states @ self.in_proj.weight.t()
z, xBC, dt = torch.split(
xz,
[
self.d_inner_local,
self.d_inner_local + 2 * self.ngroups_local * self.d_state,
self.nheads_local,
],
dim=-1,
)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = xBC
xBC = torch.sum(
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
) # (B D)
if self.conv1d.bias is not None:
xBC = xBC + self.conv1d.bias
xBC = self.act(xBC).to(dtype=dtype)
else:
xBC = causal_conv1d_update(
xBC,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
x, B, C = torch.split(
xBC,
[
self.d_inner_local,
self.ngroups_local * self.d_state,
self.ngroups_local * self.d_state,
],
dim=-1,
)
A = -torch.exp(self.A_log.float())
# SSM step
if selective_state_update is None:
if self.ngroups_local > 1:
B = rearrange(B, "b (g n) -> b g n", n=self.d_state)
C = rearrange(C, "b (g n) -> b g n", n=self.d_state)
B = repeat(B, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local)
C = repeat(C, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local)
dt = repeat(dt, "b h -> b (h p)", p=self.headdim)
dt_bias = repeat(self.dt_bias, "h -> (h p)", p=self.headdim)
A = repeat(A, "h -> (h p) n", p=self.headdim, n=self.d_state)
D = repeat(self.D, "h -> (h p)", p=self.headdim)
dt = F.softplus(dt + dt_bias.to(dtype=dt.dtype))
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB_x = torch.einsum('bd,bdn,bd->bdn', dt, B, x)
ssm_state.copy_(
ssm_state * rearrange(dA, "b (h p) n -> b h p n", p=self.headdim)
+ rearrange(dB_x, "b (h p) n -> b h p n", p=self.headdim)
)
y = torch.einsum(
"bdn,bdn->bd",
rearrange(ssm_state.to(dtype), "b h p n -> b (h p) n", p=self.headdim),
C,
)
y = y + D.to(dtype) * x
if not self.rmsnorm:
y = y * self.act(z) # (B D)
else:
# Discretize A and B (b (g n))
dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
dA = torch.exp(dt * A)
x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
if not self.rmsnorm:
y = y * self.act(z) # (B D)
else:
A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
dt = repeat(dt, "b h -> b h p", p=self.headdim)
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
D = repeat(self.D, "h -> h p", p=self.headdim)
B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local)
C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local)
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
if not self.rmsnorm:
z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
y = selective_state_update(
ssm_state,
x_reshaped,
dt,
A,
B,
C,
D,
z=z if not self.rmsnorm else None,
dt_bias=dt_bias,
dt_softplus=True,
)
y = rearrange(y, "b h p -> b (h p)")
if self.rmsnorm:
y = self.norm(y, z)
# b pd --> b d
out = y @ self.out_proj.weight.t()
out = reduce_from_tensor_model_parallel_region(out)
return out.unsqueeze(0), conv_state, ssm_state
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
conv_state = torch.zeros(
batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype
)
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
# ssm_dtype = torch.float32
ssm_state = torch.zeros(
batch_size,
self.nheads_local,
self.headdim,
self.d_state,
device=device,
dtype=ssm_dtype,
)
return conv_state, ssm_state
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
assert self.layer_idx is not None
if self.layer_idx not in inference_params.key_value_memory_dict:
conv_state = torch.zeros(
batch_size,
self.conv1d.weight.shape[0],
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.nheads_local,
self.headdim,
self.d_state,
device=self.in_proj.weight.device,
dtype=self.in_proj.weight.dtype,
)
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states:
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
import socket
from pathlib import Path
import torch
try:
from triton.runtime.cache import FileCacheManager
except ImportError:
raise ImportError("triton is required by the Mamba model but cannot be imported")
def get_rank():
return torch.distributed.get_rank()
def default_cache_dir():
return os.path.join(Path.home(), ".triton", "cache")
class ParallelFileCacheManager(FileCacheManager):
# See https://github.com/triton-lang/triton/blob/main/python/triton/runtime/cache.py
# When running Triton with multiple ranks, they each create their own cache manager. Their input
# keys to that class are mostly (but not entirely) the same across ranks, which leads many ranks
# to write to the same 'key' directories in the cache dir at the same time during compilation,
# leading to conflicts. This class works around that by making each cache dir rank-specific,
# appending "rank_<host>_<pid>" to the cache directory path.
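# For example, with the default cache dir, a (hypothetical) key 'abc123' on host 'node01'
# with pid 4242 maps to ~/.triton/cache/rank_node01_4242/abc123.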
def __init__(self, key):
self.key = key
self.lock_path = None
# create cache directory if it doesn't exist
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
self.cache_dir = os.path.join(
self.cache_dir, "rank_{}_{}".format(socket.gethostname(), os.getpid())
)
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
copy_tensor_model_parallel_attributes,
linear_with_grad_accumulation_and_async_allreduce,
param_is_not_tensor_parallel_duplicate,
set_defaults_if_not_set_tensor_model_parallel_attributes,
set_tensor_model_parallel_attributes,
)
from .mappings import (
all_gather_last_dim_from_tensor_parallel_region,
all_to_all,
all_to_all_hp2sp,
all_to_all_sp2hp,
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
gather_from_sequence_parallel_region_to_moe,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_last_dim_to_tensor_parallel_region,
reduce_scatter_to_sequence_parallel_region,
reduce_scatter_to_sequence_parallel_region_from_moe,
scatter_to_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .random import (
checkpoint,
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
model_parallel_cuda_manual_seed,
)
from .utils import (
gather_split_1d_tensor,
split_tensor_along_last_dim,
split_tensor_into_1d_equal_chunks,
)
__all__ = [
# cross_entropy.py
"vocab_parallel_cross_entropy",
# data.py
"broadcast_data",
# layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_tensor_model_parallel_attributes",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
"linear_with_grad_accumulation_and_async_allreduce",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
"reduce_from_tensor_model_parallel_region",
"reduce_scatter_to_sequence_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
"split_tensor_into_1d_equal_chunks",
"gather_split_1d_tensor",
"gather_from_sequence_parallel_region_to_moe",
"reduce_scatter_to_sequence_parallel_region_from_moe",
]
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import torch
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .utils import VocabUtility
class VocabParallelCrossEntropy:
"""Computes the Cross Entropy Loss splitting the Vocab size across tensor parallel
ranks. This implementation is used in both fused and unfused cross entropy implementations
"""
@staticmethod
def calculate_logits_max(
vocab_parallel_logits: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
vocab_parallel_logits = vocab_parallel_logits.float()
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
return vocab_parallel_logits, logits_max
@staticmethod
def calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor, target: torch.Tensor, logits_max: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# In-place subtraction reduces memory pressure.
vocab_parallel_logits -= logits_max.unsqueeze(dim=-1)
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits
@staticmethod
def calculate_cross_entropy_loss(
exp_logits: torch.Tensor, predicted_logits: torch.Tensor, sum_exp_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
return exp_logits, loss
@staticmethod
def prepare_gradient_calculation_operands(
softmax: torch.Tensor, target_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
softmax_update = 1.0 - target_mask.view(-1).float()
return grad_2d, arange_1d, softmax_update, grad_input
@staticmethod
def calculate_gradients(
grad_2d: torch.Tensor,
arange_1d: torch.Tensor,
masked_target_1d: torch.Tensor,
softmax_update: torch.Tensor,
grad_input: torch.Tensor,
grad_output: torch.Tensor,
) -> torch.Tensor:
grad_2d[arange_1d, masked_target_1d] -= softmax_update
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max(
vocab_parallel_logits
)
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
(
target_mask,
masked_target_1d,
predicted_logits,
sum_exp_logits,
exp_logits,
) = VocabParallelCrossEntropy.calculate_predicted_logits(
vocab_parallel_logits, target, logits_max
)
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(
predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
)
torch.distributed.all_reduce(
sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
)
exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss(
exp_logits, predicted_logits, sum_exp_logits
)
vocab_size = exp_logits.size(-1)
if label_smoothing > 0:
"""
We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
= (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
= (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
= ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
= ((K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
= (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K
From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
"""
assert 1.0 > label_smoothing > 0.0
smoothing = label_smoothing * vocab_size / (vocab_size - 1)
# Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
log_probs = torch.log(exp_logits)
mean_log_probs = log_probs.mean(dim=-1)
loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# Retrieve tensors from the forward pass.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
(
grad_2d,
arange_1d,
softmax_update,
grad_input,
) = VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
if label_smoothing > 0:
smoothing = label_smoothing * vocab_size / (vocab_size - 1)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
average_grad = 1 / vocab_size
grad_2d[arange_1d, :] -= smoothing * average_grad
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
else:
grad_input = VocabParallelCrossEntropy.calculate_gradients(
grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
)
return grad_input, None, None
def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
"""
Performs cross entropy loss when logits are split across tensor parallel ranks
Args:
vocab_parallel_logits: logits split across tensor parallel ranks
dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]
target: correct vocab ids of dimension [sequence_length, micro_batch_size]
label_smoothing: smoothing factor, must be in range [0.0, 1.0)
default is no smoothing (=0.0)
"""
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_src_rank,
)
_MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
"""Check that all the keys have the same target data type."""
for key in keys:
assert data[key].dtype == target_dtype, (
'{} has data type {} which '
'is different than {}'.format(key, data[key].dtype, target_dtype)
)
def _build_key_size_numel_dictionaries(keys, data):
"""Build the size on rank 0 and broadcast."""
max_dim = _MAX_DATA_DIM
sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero.
if get_tensor_model_parallel_rank() == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
size = data[key].size()
for i, s in enumerate(size):
sizes[i + offset] = s
offset += max_dim
# Move to GPU and broadcast.
sizes_cuda = torch.tensor(sizes, dtype=torch.long, device='cuda')
torch.distributed.broadcast(
sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
)
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
key_size = {}
key_numel = {}
total_numel = 0
offset = 0
for key in keys:
i = 0
size = []
numel = 1
while sizes_cpu[offset + i] > 0:
this_size = sizes_cpu[offset + i]
size.append(this_size)
numel *= this_size
i += 1
key_size[key] = size
key_numel[key] = numel
total_numel += numel
offset += max_dim
return key_size, key_numel, total_numel
def broadcast_data(keys, data, datatype):
"""Broadcast data from rank zero of each model parallel group to the
members of the same model parallel group.
Args:
keys: list of keys in the data dictionary to be broadcast
data: data dictionary of string keys and cpu tensor values.
datatype: torch data type of all tensors in data associated
with keys.
"""
# Build (key, size) and (key, number of elements) dictionaries along
# with the total number of elements on all ranks.
key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)
# Pack on rank zero.
if get_tensor_model_parallel_rank() == 0:
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
else:
flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
# Broadcast
torch.distributed.broadcast(
flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
)
# Unpack
output = {}
offset = 0
for key in keys:
size = key_size[key]
numel = key_numel[key]
output[key] = flatten_data.narrow(0, offset, numel).view(size)
offset += numel
return output