Commit 996ea169 authored by Przemek Tredak

Initial code drop


Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Methods needed for distributed training (DP/TP)."""
from typing import Union, Optional, Callable, Tuple
import torch
from torch.cuda import _lazy_call
from torch.utils.checkpoint import detach_variable
from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
"tensor_model_parallel": False,
"partition_dim": -1,
"partition_stride": 1,
}
def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None:
"""Sets the random number generator state of the current GPU.
Arguments:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues when running with 4 or more GPUs.
"""
if device == -1:
device = torch.device("cuda")
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
def cb() -> None:
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
def set_tensor_model_parallel_attributes(
tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
"""set attributes needed for TP"""
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
# Set the attributes.
setattr(tensor, "tensor_model_parallel", is_parallel)
setattr(tensor, "partition_dim", dim)
setattr(tensor, "partition_stride", stride)
def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
"""Return world size for the distributed group."""
if group is None:
return 1
return torch.distributed.get_world_size(group=group)
def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
"""Return my rank for the distributed group."""
return torch.distributed.get_rank(group=group)
def initialize_affine_weight_gpu(
weight: torch.Tensor,
init_method: Callable,
get_rng_state_tracker: Callable,
partition_dim: int,
stride: int = 1,
) -> None:
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
if get_rng_state_tracker is None:
init_method(weight)
return
with get_rng_state_tracker().fork():
init_method(weight)
def split_tensor_into_1d_equal_chunks(
tensor: torch.Tensor, tp_group: dist_group_type, new_buffer: bool = False
) -> torch.Tensor:
"""Break a tensor into equal 1D chunks."""
partition_size = torch.numel(tensor) // get_distributed_world_size(tp_group)
start_index = partition_size * get_distributed_rank(tp_group)
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(
partition_size,
dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(
tensor: torch.Tensor, tp_group: dist_group_type
) -> torch.Tensor:
"""Opposite of above function, gather values from model parallel ranks."""
numel_gathered = torch.numel(tensor) * get_distributed_world_size(tp_group)
gathered = torch.empty(
numel_gathered,
dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
torch.distributed._all_gather_base(gathered, tensor, group=tp_group)
return gathered
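# Round-trip sketch (an illustrative comment, not part of the API): assuming `full` is a
# CUDA tensor whose numel is divisible by the tensor-parallel world size,
#     local = split_tensor_into_1d_equal_chunks(full, tp_group, new_buffer=True)
#     rebuilt = gather_split_1d_tensor(local, tp_group)
# gives every rank a flat copy of `full` (i.e. rebuilt.equal(full.view(-1))).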
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly
tracked/set/reset.
"""
@staticmethod
def forward(
ctx,
run_function: Callable,
distribute_saved_activations: bool,
get_cuda_rng_tracker: Callable,
tp_group: dist_group_type,
*args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
ctx.run_function = run_function
ctx.distribute_saved_activations = distribute_saved_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
with torch.no_grad():
outputs = run_function(*args)
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if distribute_saved_activations:
ctx.input_0_shape = args[0].data.shape
safely_set_viewless_tensor_data(
args[0],
split_tensor_into_1d_equal_chunks(
args[0].data, tp_group, new_buffer=True
),
)
# Store everything.
ctx.save_for_backward(*args)
ctx.get_cuda_rng_tracker = get_cuda_rng_tracker
ctx.tp_group = tp_group
return outputs
@staticmethod
def backward(
ctx, *args: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
inputs = ctx.saved_tensors
get_cuda_rng_tracker = ctx.get_cuda_rng_tracker
if ctx.distribute_saved_activations:
safely_set_viewless_tensor_data(
inputs[0],
gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(
ctx.input_0_shape
),
)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
# Compute the forward pass.
detached_inputs = detach_variable(inputs)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs
)
return (None, None, None, None) + grads
def checkpoint(
function: Callable,
distribute_saved_activations: bool,
get_cuda_rng_tracker: Callable,
tp_group: dist_group_type,
*args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(
function, distribute_saved_activations, get_cuda_rng_tracker, tp_group, *args
)
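# Usage sketch (illustrative; `block`, `get_rng_tracker`, and `tp_group` are assumed to be
# provided by the caller):
#     hidden = checkpoint(block, False, get_rng_tracker, tp_group, hidden)
# The forward pass runs under no_grad and is recomputed in backward with the saved RNG states.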
def reduce_scatter_along_first_dim(
input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_distributed_world_size(tp_group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_, None
dim_size = list(input_.size())
assert (
dim_size[0] % world_size == 0
), "First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
handle = torch.distributed._reduce_scatter_base(
output, input_.contiguous(), group=tp_group, async_op=async_op
)
return output, handle
def gather_along_first_dim(
input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Gather tensors and concatinate along the first dimension."""
world_size = get_distributed_world_size(tp_group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_, None
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
handle = torch.distributed._all_gather_base(
output, input_.contiguous(), group=tp_group, async_op=async_op
)
return output, handle
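# Shape note (descriptive comment): with world size N, reduce_scatter_along_first_dim maps a
# [N*s, ...] tensor to the local [s, ...] shard summed across ranks, while
# gather_along_first_dim maps a local [s, ...] shard to the full [N*s, ...] tensor. When
# async_op=True, the returned handle must be waited on before the output is consumed.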
def gather_along_last_dim(
input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Gather tensors and concatinate along the last dimension."""
world_size = get_distributed_world_size(tp_group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_, None
dim_size = list(input_.size())
dim_size[-1] = dim_size[-1] * world_size
output = torch.empty(
dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
)
handle = torch.distributed._all_gather_base(
output, input_.contiguous(), group=tp_group, async_op=async_op
)
return output, handle
def allreduce(
input_: torch.Tensor,
tp_group: Optional[dist_group_type] = None,
async_op: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_distributed_world_size(tp_group) == 1:
return input_, None
# All-reduce.
handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)
return input_, handle
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 utilies for TransformerEngine"""
from contextlib import contextmanager
from typing import Callable, List, Optional, Dict, Any, Tuple, Union
import torch
import transformer_engine_extensions as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
_FP8_ENABLED = False
_FP8_RECIPE = None
_FP8_DISTRIBUTED_GROUP = None
_IS_FIRST_FP8_MODULE = False
_FP8_AUTOCAST_COUNTER = 0
_FP8_CURRENT_CONTEXT_ID = 0
_FP8_AUTOCAST_DEPTH = 0
_global_fp8_buffer = {}
_amax_forward_global_reduce_func = None
_buffer_delete_key_fwd = None
_buffer_delete_key_bwd = None
def get_meta_tensor_key(forward: bool = True) -> str:
"""Returns scaling key in `fp8_meta`."""
if forward:
return "scaling_fwd"
return "scaling_bwd"
def get_buffer_position_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "global_fp8_buffer_pos_fwd"
return "global_fp8_buffer_pos_bwd"
def get_autocast_key(forward: bool = True) -> str:
"""Returns module position key in `fp8_meta`."""
if forward:
return "autocast_id_fwd"
return "autocast_id_bwd"
def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
"""Returns global fp8 buffer."""
return _global_fp8_buffer
def set_global_fp8_buffer(buffer: Dict[str, List[torch.Tensor]]) -> None:
"""Sets global fp8 buffer."""
global _global_fp8_buffer
# Map all tensors back to GPU.
for k, v in buffer.items():
buffer[k] = [tensor.cuda() for tensor in v]
_global_fp8_buffer = buffer
def setup_amax_forward_global_reduce_func(f: Callable) -> None:
"""Sets up the function to call during autocast exit."""
global _amax_forward_global_reduce_func
_amax_forward_global_reduce_func = f
def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str:
"""Return a key in `_global_fp8_buffer` for the AMAX storage."""
if forward:
return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}"
return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}"
def add_amax_to_global_buffer(fp8_meta: Dict[str, Any], forward: bool = True) -> None:
"""Append 1D tensor `amax` to global buffer."""
global _global_fp8_buffer
buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
fp8_meta_tensor_key = get_meta_tensor_key(forward=forward)
buffer_position_key = get_buffer_position_key(forward=forward)
if buffer_key not in _global_fp8_buffer:
_global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
else:
_global_fp8_buffer[buffer_key].append(
fp8_meta[fp8_meta_tensor_key].amax_history[0]
)
if buffer_position_key not in fp8_meta:
fp8_meta[buffer_position_key] = len(_global_fp8_buffer[buffer_key]) - 1
def copy_amax_from_global_buffer(
fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Populate current amax with the correct location from buffer."""
fp8_meta_tensor_key = get_meta_tensor_key(forward=forward)
buffer_position_key = get_buffer_position_key(forward=forward)
if buffer_position_key not in fp8_meta:
return
amax_buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
fp8_meta[fp8_meta_tensor_key].amax_history[0] = _global_fp8_buffer[amax_buffer_key][
fp8_meta[buffer_position_key]
]
def set_amax_buffer_key_deletion(
fp8_meta: Dict[str, Any], forward: bool = True
) -> None:
"""Delete this amax key from global buffer during autocast end."""
if get_autocast_key(forward=forward) not in fp8_meta:
return
global _buffer_delete_key_fwd, _buffer_delete_key_bwd
if forward:
_buffer_delete_key_fwd = get_amax_buffer_key(fp8_meta, forward=forward)
else:
_buffer_delete_key_bwd = get_amax_buffer_key(fp8_meta, forward=forward)
def get_default_fp8_recipe() -> DelayedScaling:
"""FP8 recipe if not provided by user
Margin = 0, interval = 1, E4M3
"""
return DelayedScaling()
@contextmanager
def fp8_autocast(
enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
fp8_group: Optional[dist_group_type] = None,
) -> None:
"""
Context manager for FP8 usage.
.. code-block:: python
with fp8_autocast(enabled=True):
out = model(inp)
.. note::
Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors
with shapes where both dimensions are divisible by 16. In terms of the input to the full
Transformer network, this typically requires padding the sequence length to a multiple of 16.
Parameters
----------
enabled: bool, default = `False`
whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = `None`
recipe used for FP8 training.
fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
distributed group over which amaxes for the fp8 tensors
are reduced at the end of each training step.
"""
global _FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
global _global_fp8_buffer, _buffer_delete_key_fwd
fp8_state = (_FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
try:
_FP8_ENABLED = enabled
_FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe
_FP8_DISTRIBUTED_GROUP = fp8_group
if _FP8_AUTOCAST_DEPTH == 0:
_IS_FIRST_FP8_MODULE = True
_FP8_AUTOCAST_COUNTER += 1
_FP8_AUTOCAST_DEPTH += 1
if enabled:
assert (
torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9
), "Device compute capability 9.x required for FP8 execution."
yield
finally:
_FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
_IS_FIRST_FP8_MODULE = False
_FP8_AUTOCAST_DEPTH -= 1
if _FP8_AUTOCAST_DEPTH == 0:
if callable(_amax_forward_global_reduce_func):
_amax_forward_global_reduce_func()
delete_key_from_amax_buffer(forward=True)
def get_fp8_context_id() -> int:
"""Returns an ID for the current FP8 context."""
return _FP8_CURRENT_CONTEXT_ID
def set_fp8_context_id(ctx_id: int) -> None:
"""Sets the current FP8 context."""
global _FP8_CURRENT_CONTEXT_ID
_FP8_CURRENT_CONTEXT_ID = ctx_id
def new_fp8_context_id() -> int:
"""Returns global autocast counter as a proxy to be used
as the autocast ID for FP8 modules.
"""
return _FP8_AUTOCAST_COUNTER
def is_fp8_enabled() -> bool:
"""Is FP8 enabled"""
return _FP8_ENABLED
def is_first_fp8_module():
"""Returns `True` only the first time when called multiple
times from within the same `fp8_autocast` context.
"""
global _IS_FIRST_FP8_MODULE
tmp = _IS_FIRST_FP8_MODULE
_IS_FIRST_FP8_MODULE = False
return tmp
def get_fp8_recipe() -> DelayedScaling:
"""Return the fp8 recipe"""
return _FP8_RECIPE
def get_fp8_group() -> Union[dist_group_type, None]:
"""Return the fp8 group for scale/amax comm"""
return _FP8_DISTRIBUTED_GROUP
def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
"""Update amax history and set next amax to zero."""
amax_history = torch.roll(amax_history, -1, 0)
amax_history[0].fill_(0.0)
return amax_history
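# Descriptive note: `update_amax_history` shifts the history window by one position and
# zeroes slot 0, which is where the next iteration's amax will be recorded.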
@torch.jit.script
def _default_get_amax(
amax_history: torch.Tensor,
amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Default function to obtain amax from history."""
if amax_compute_algo == "max":
amax = torch.max(amax_history, dim=0).values
else: # amax_compute_algo == "most_recent"
amax = amax_history[0]
amax_history = update_amax_history(amax_history)
return amax_history, amax
@torch.jit.script
def _default_sf_compute(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: float,
margin: int,
) -> torch.Tensor:
"""Default function to convert amax to scaling factor."""
exp = torch.floor(torch.log2(fp8_max / amax)) - margin
sf = torch.round(torch.pow(2, torch.abs(exp)))
sf = torch.where(amax > 0.0, sf, scale)
sf = torch.where(torch.isfinite(amax), sf, scale)
sf = torch.where(exp < 0, 1 / sf, sf)
return sf
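# Worked example (descriptive comment, margin=0): for fp8_max=448.0 and amax=0.5,
# exp = floor(log2(448.0 / 0.5)) = 9, so the new scale is 2**9 = 512. A zero or non-finite
# amax keeps the previous scale, and a negative exponent selects 1/sf instead, so that
# amax * scale always stays within the representable FP8 range.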
@torch.jit.script
def fused_amax_and_scale_update(
amax_history: torch.Tensor,
scale: torch.Tensor,
fp8_max: float,
margin: int,
amax_compute_algo: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Amax to scale conversion."""
# Get amax from history.
amax_history, amax = _default_get_amax(
amax_history,
amax_compute_algo,
)
# Calculate new scaling factor.
return amax_history, _default_sf_compute(
amax,
scale,
fp8_max,
margin,
)
def _compute_amax(
amax_history: torch.Tensor,
recipe: DelayedScaling,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Obtain the amax from the history."""
if callable(recipe.amax_compute_algo):
amax = recipe.amax_compute_algo(amax_history)
amax_history = update_amax_history(amax_history)
return amax_history, amax
return _default_get_amax(
amax_history,
recipe.amax_compute_algo,
)
def _compute_scaling_factor(
amax: torch.Tensor,
scale: torch.Tensor,
fp8_max: float,
recipe: DelayedScaling,
) -> torch.Tensor:
"""Convert amax to scaling factor."""
if recipe.scaling_factor_compute_algo is None:
return _default_sf_compute(
amax,
scale,
fp8_max,
recipe.margin,
)
return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe)
def amax_and_scale_update(
fp8_meta: Dict[str, Any],
fwd_update: bool,
) -> None:
"""Updates fp8 amaxes/scales for fwd | bwd."""
amax_compute = fp8_meta["recipe"].amax_compute_algo
sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo
fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
if not callable(amax_compute) and sf_compute is None:
(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
) = fused_amax_and_scale_update(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_max_key],
fp8_meta["recipe"].margin,
fp8_meta["recipe"].amax_compute_algo,
)
else:
fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta["recipe"],
)
fp8_meta[fp8_meta_tensor_key].scale = _compute_scaling_factor(
amax,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_max_key],
fp8_meta["recipe"],
)
def get_fp8_te_dtype(
fp8_recipe: DelayedScaling, fprop_tensor: bool = True
) -> tex.DType:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
):
return tex.DType.kFloat8E4M3
return tex.DType.kFloat8E5M2
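# E.g. with Format.HYBRID the forward (fprop) tensors use E4M3 while gradient tensors use
# E5M2; with Format.E4M3 both directions use E4M3 (descriptive comment).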
def reduce_tensor_across_group_op_max(
tensor: torch.Tensor, group: dist_group_type
) -> None:
"""Reduce tensor across given group."""
if torch.distributed.is_initialized():
torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=group,
async_op=False,
)
def global_amax_reduction(
fp8_meta: Dict[str, Any],
reduce_amax_across_tp_group: bool = False,
tp_group: Optional[dist_group_type] = None,
forward: bool = True,
) -> None:
"""Concatenate, reduce, and split amaxes in the global buffer."""
global _global_fp8_buffer
amax_buffer_key = get_amax_buffer_key(fp8_meta, forward=forward)
# Key already deleted.
if amax_buffer_key not in _global_fp8_buffer:
return
chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])
reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"])
if reduce_amax_across_tp_group:
reduce_tensor_across_group_op_max(contiguous_amax, tp_group)
_global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
def delete_key_from_amax_buffer(forward: bool = True) -> None:
"""Delete the key from global amax buffer."""
global _global_fp8_buffer, _buffer_delete_key_fwd, _buffer_delete_key_bwd
if forward:
if (
_buffer_delete_key_fwd is not None
and _buffer_delete_key_fwd in _global_fp8_buffer
):
del _global_fp8_buffer[_buffer_delete_key_fwd]
else:
if (
_buffer_delete_key_bwd is not None
and _buffer_delete_key_bwd in _global_fp8_buffer
):
del _global_fp8_buffer[_buffer_delete_key_bwd]
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""NVFuser functions and JIT utilities"""
from typing import Callable, Tuple
import torch
def set_jit_fusion_options() -> None:
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
@torch.jit.script
def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Bias-GeLU fused"""
x = inp + bias
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bgrad_dgelu_fused_(
grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Bgrad-Dgelu fused"""
x = inp + bias
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * (
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
) + 0.5 * (1 + tanh_out)
dgelu = ff * grad_output
bgrad = dgelu.sum(dim=0)
return bgrad, dgelu
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with torch.cuda.amp.autocast(enabled=False):
return bias_gelu_fused_(inp, bias)
def bgrad_dgelu_fused(
grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Disable native AMP for `bgrad_dgelu_fused_`"""
with torch.cuda.amp.autocast(enabled=False):
return bgrad_dgelu_fused_(grad_output, inp, bias)
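# Sanity-check sketch (illustrative only, assumes a CUDA device):
#     x = torch.randn(8, 16, device="cuda", requires_grad=True)
#     b = torch.randn(16, device="cuda", requires_grad=True)
#     y = bias_gelu_fused(x, b)
#     bgrad, dgelu = bgrad_dgelu_fused(torch.ones_like(y), x, b)
# `dgelu` and `bgrad` match autograd's gradients for the tanh GeLU approximation above.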
def bias_dropout_add(
x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float,
training: bool,
) -> torch.Tensor:
"""dropout(inp + bias) + residual"""
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training: bool) -> Callable:
"""bias_dropout_add based on training or not"""
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train_(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
"""Jit fused bias_dropout_add for training"""
return bias_dropout_add(x, bias, residual, prob, True)
def bias_dropout_add_fused_train(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
"""Disable native AMP and enable grad for BDA"""
with torch.enable_grad():
with torch.cuda.amp.autocast(enabled=False):
return bias_dropout_add_fused_train_(x, bias, residual, prob)
@torch.jit.script
def bias_dropout_add_fused_inference_(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
"""Jit fused bias_dropout_add for inference"""
return bias_dropout_add(x, bias, residual, prob, False)
def bias_dropout_add_fused_inference(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
"""Disable native AMP for BDA"""
with torch.cuda.amp.autocast(enabled=False):
return bias_dropout_add_fused_inference_(x, bias, residual, prob)
def warmup_jit_bias_dropout_add(
hidden_size: int, dtype: torch.dtype, seq_length: int, micro_batch_size: int
) -> None:
"""Compilie BDA JIT function before the main training steps"""
# Warmup fused bias+dropout+add
inp = torch.rand(
(seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda"
)
residual = torch.rand(
(seq_length, micro_batch_size, hidden_size), dtype=dtype, device="cuda"
)
bias = torch.rand((hidden_size), dtype=dtype, device="cuda")
dropout_rate = 0.1
# Warm up JIT fusions with the input requires_grad settings used in both
# forward prop and recomputation
for input_grad, bias_grad, residual_grad in zip(
[False, True], [True, True], [True, True]
):
inp.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad
for _ in range(5):
output = bias_dropout_add_fused_train(inp, bias, residual, dropout_rate)
del bias, inp, residual, output
torch.cuda.empty_cache()
def warmup_jit_bias_dropout_add_all_dtypes(
hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
"""Call `warmup_jit_bias_dropout_add` for all training dtypes"""
for dtype in [torch.float32, torch.bfloat16, torch.float16]:
warmup_jit_bias_dropout_add(hidden_size, dtype, seq_length, micro_batch_size)
def warmup_jit_bias_gelu(
ffn_hidden_size_per_partition: int,
dtype: torch.dtype,
seq_length: int,
micro_batch_size: int,
) -> None:
"""Compilie bias-gelu JIT function before the main training steps"""
# Warmup fused bias+gelu
bias = torch.rand(ffn_hidden_size_per_partition, dtype=dtype, device="cuda")
inp = torch.rand(
(seq_length, micro_batch_size, ffn_hidden_size_per_partition),
dtype=dtype,
device="cuda",
)
# Warm up JIT fusions with the input requires_grad settings used in both
# forward prop and recomputation
for bias_grad, input_grad in zip([True, True], [False, True]):
bias.requires_grad, inp.requires_grad = bias_grad, input_grad
for _ in range(5):
output = bias_gelu_fused(inp, bias)
del bias, inp, output
def warmup_jit_bias_gelu_all_dtypes(
ffn_hidden_size: int, seq_length: int, micro_batch_size: int
) -> None:
"""Call `warmup_jit_bias_gelu` for all training dtypes"""
for dtype in [torch.float32, torch.bfloat16, torch.float16]:
warmup_jit_bias_gelu(ffn_hidden_size, dtype, seq_length, micro_batch_size)
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level Transformer Engine PyTorch modules"""
import os
import warnings
from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, List, Any
from functools import partial
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_extensions as tex
from .fp8 import (
is_fp8_enabled,
get_fp8_recipe,
get_fp8_group,
get_default_fp8_recipe,
get_fp8_te_dtype,
is_first_fp8_module,
new_fp8_context_id,
get_fp8_context_id,
set_fp8_context_id,
add_amax_to_global_buffer,
copy_amax_from_global_buffer,
global_amax_reduction,
setup_amax_forward_global_reduce_func,
amax_and_scale_update,
get_global_fp8_buffer,
set_global_fp8_buffer,
set_amax_buffer_key_deletion,
delete_key_from_amax_buffer,
)
from .jit import (
bias_gelu_fused,
bgrad_dgelu_fused,
set_jit_fusion_options,
warmup_jit_bias_gelu_all_dtypes,
)
from .utils import (
divide,
get_default_init_method,
cast_if_needed,
)
from .distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
allreduce,
initialize_affine_weight_gpu,
reduce_scatter_along_first_dim,
gather_along_first_dim,
gather_along_last_dim,
)
from .cpp_extensions import (
fp8_gemm,
gemm,
fp8_cast_transpose_fused,
fp8_cast_transpose_bgrad_fused,
fp8_gelu,
fp8_cast_transpose_bgrad_dgelu_fused,
layernorm_fwd_fp8,
cast_to_fp8,
cast_from_fp8,
)
from .constants import GemmParallelModes, dist_group_type, TE_DType
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_cublas_workspace = None
def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB if using Hopper, 4 MiB for all other architectures."""
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
return 33_554_432
return 4_194_304
def get_workspace() -> torch.Tensor:
"""Returns workspace for cublas."""
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.int8, device="cuda"
)
return _cublas_workspace
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
def __init__(self) -> None:
super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.fp8 = False
self.fp8_meta = {}
self.fp8_meta["fp8_group"] = None
self.fp8_meta["recipe"] = get_default_fp8_recipe()
self.fp8_meta_tensors_initialized = False
self.tp_group = None
self.tp_group_initialized = False
self.tp_size = 1
self.sequence_parallel = False
self.fp8_weight_shapes = []
def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
num_fp8_tensors = (
self.fp8_meta["num_gemms"] * 2 if fwd else self.fp8_meta["num_gemms"]
)
self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones(
num_fp8_tensors, dtype=torch.float32, device="cuda"
)
self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones(
num_fp8_tensors, dtype=torch.float32, device="cuda"
)
self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros(
self.fp8_meta["recipe"].amax_history_len,
num_fp8_tensors,
dtype=torch.float32,
device="cuda",
)
def init_fp8_meta_tensors(self) -> None:
"""Init scales and amaxes."""
# Checkpoint loaded
if self.fp8_meta_tensors_initialized:
return
self.set_meta_tensor(True)
self.set_meta_tensor(False)
def get_extra_state(self) -> Union[List[Any], None]:
"""Save before checkpointing."""
if self.fp8:
state = []
state.append(self.fp8_meta["scaling_fwd"].scale)
state.append(self.fp8_meta["scaling_fwd"].amax_history)
state.append(self.fp8_meta["scaling_bwd"].scale)
state.append(self.fp8_meta["scaling_bwd"].amax_history)
state.append(get_global_fp8_buffer())
state.append(self.fp8_meta["update_amax_and_scale_fwd"])
state.append(self.fp8_meta["global_fp8_buffer_pos_fwd"])
state.append(self.fp8_meta["global_fp8_buffer_pos_bwd"])
state.append(self.fp8_meta["autocast_id_fwd"])
state.append(self.fp8_meta["autocast_id_bwd"])
return state
return None
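# Note: `set_extra_state` below unpacks this list by position, so the ordering of the
# entries appended here must not change.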
def set_extra_state(self, state: Union[List[Any], None]) -> None:
"""Load previous state."""
if state is None:
return
# Retrieve checkpointed items.
scale_fwd = state[0]
amax_history_fwd = state[1]
scale_bwd = state[2]
amax_history_bwd = state[3]
self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
self.fp8_meta["num_gemms"] = (
amax_history_fwd.shape[1] // 2
) # Two FWD tensors per GEMM
# Initialize before loading
self.init_fp8_meta_tensors()
self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
self.fp8_meta_tensors_initialized = True
# Restore global FP8 buffer state.
set_global_fp8_buffer(state[4])
self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
self.fp8_meta["autocast_id_fwd"] = state[8]
self.fp8_meta["autocast_id_bwd"] = state[9]
def set_activation_dtype(self, inp: torch.Tensor) -> None:
"""Get activation data type for AMP."""
# Native AMP (`torch.autocast`) gets highest priority
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
return
# All checks after this have already been performed once, thus skip
# We assume that the user doesn't change input types across iterations
if hasattr(self, "activation_dtype"):
return
assert all(
(
(inp.dtype == param.dtype) if param is not None else True
for param in self.parameters()
)
), (
"Data type for activations and weights must "
"match when outside of autocasted region"
)
assert all(
(
(inp.dtype == buf.dtype) if buf is not None else True
for buf in self.buffers()
)
), (
"Data type for activations and buffers must "
"match when outside of autocasted region"
)
self.activation_dtype = inp.dtype
def set_fp8_weights(self) -> None:
"""Initializes FP8 weights for the module as class attributes. These
are not parameters or buffers since we do not want functions such as
`.to(dtype)` or `.to(device)` to affect them. These also do not need
to be checkpointed. During `init` phase of the module, the attribute
`fp8_weight_shapes` must be populated with the tensor shapes for FP8
weights. This function will iterate over those shapes and initialize
the respective attributes named `weight1_fp8`, `weight2_fp8`, ...
"""
for i, shape in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_attr = f"weight{i}_fp8"
weight_transpose_attr = f"weight{i}_t_fp8"
if self.fp8:
if not hasattr(self, weight_cast_attr):
setattr(
self,
weight_cast_attr,
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.int8,
),
)
if not hasattr(self, weight_transpose_attr):
setattr(
self,
weight_transpose_attr,
torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.int8,
),
)
else:
setattr(self, weight_cast_attr, torch.Tensor())
setattr(self, weight_transpose_attr, torch.Tensor())
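# For example, a module with fp8_weight_shapes = [(out_features, in_features)] ends up with
# a `weight1_fp8` tensor attribute of that shape and a `weight1_t_fp8` attribute of the
# transposed shape, both backed by torch.int8 storage holding the FP8 payload
# (descriptive comment).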
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group."""
self.tp_group = tp_group
self.tp_group_initialized = True
def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
# If fp8 isn't enabled, turn off and return.
if not is_fp8_enabled():
self.fp8 = False
return
# FP8 is already enabled and recipe is the same, don't do anything.
if self.fp8 and get_fp8_recipe() == self.fp8_meta["recipe"]:
return
# Set FP8, recipe, and other FP8 metadata
self.fp8 = True
self.fp8_meta["recipe"] = get_fp8_recipe()
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = get_fp8_group()
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes
self.init_fp8_meta_tensors()
def pre_forward(self, inp: torch.Tensor, num_gemms: int = 1) -> None:
"""Checks and prep for FWD."""
assert inp.is_cuda, "TransformerEngine needs CUDA."
if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized."
self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
# Previous iteration was grad_enabled
copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(self.fp8_meta, True)
set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
if self.fp8 and torch.is_grad_enabled() and self.training:
self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]:
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
def post_forward(self) -> None:
"""This is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to set up the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent.
"""
if self.fp8 and torch.is_grad_enabled() and self.training:
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
global_amax_reduction,
self.fp8_meta,
self.sequence_parallel,
self.tp_group,
forward=True,
)
setup_amax_forward_global_reduce_func(reduce_func)
@staticmethod
def pre_backward(fp8: bool, fp8_meta: Dict[str, Any]) -> None:
"""Checks and prep for BWD."""
if not fp8:
return
# From previous iteration
copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
if "autocast_id_bwd" not in fp8_meta:
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd"]
else:
fp8_meta["autocast_id_bwd"] += 1
add_amax_to_global_buffer(fp8_meta, forward=False)
@staticmethod
def post_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
reduce_amax_across_tp_group: bool,
tp_group: Union[dist_group_type, None],
) -> None:
"""Checks and prep for BWD."""
if not fp8:
return
if fp8_meta["first_module"]:
global_amax_reduction(
fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
)
delete_key_from_amax_buffer(forward=False)
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the
host side in TE, the comm calls are always launched first, but
to ensure that the GEMM isn't scheduled first, the environment
variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
force a single channel.
"""
if self.tp_size == 1:
return
num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
if num_cuda_work_queues != 1:
warnings.warn(
"To guarantee overlapping TP and SP collectives with the backward"
"GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
)
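# E.g. launching training with `CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun ...` (illustrative
# command) avoids this warning and enables the intended comm/GEMM overlap.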
@staticmethod
def grad_output_preprocess(
ctx, grad_output: torch.Tensor, row_parallel_mode: bool
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Utility function for backward.
Returns tuple in order (all optional/None based on training precision/recipe):
R1: gathered `grad_output` in higher precision.
R2: gathered `grad_output` in FP8.
R3: R2 transposed.
R4: bias gradient on R1.
"""
grad_output = grad_output.contiguous()
grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# No-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8:
if gather_grad_output:
grad_output_mat, _ = gather_along_first_dim(
grad_output_mat, ctx.tp_group
)
return grad_output_mat, None, None, None
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# FP8 case with non-FP8 wgrad
if (
gather_grad_output
and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
):
grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
# FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
elif gather_grad_output:
if ctx.use_bias:
grad_bias = grad_output_mat.sum(dim=0)
else:
grad_bias = None
grad_output_c = cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
# FP8 case without gather: cast, transpose, bgrad fused
if ctx.use_bias:
grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
else:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
grad_output_c, grad_output_t = fp8_cast_transpose_fused(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
else:
grad_output_c = cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
grad_output_t = None
grad_bias = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
@abstractmethod
def forward(self):
"""Needs override."""
class _LayerNormLinear(torch.autograd.Function):
"""LayerNormLinear semi-top level module
Calls custom cuda extensions.
"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
weight: torch.Tensor,
weight_fp8: torch.Tensor,
weight_t_fp8: torch.Tensor,
bias: torch.Tensor,
use_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
sequence_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
return_layernorm_output: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)
# If the residual connection is after LN, we need the `ln_out`
# tensor in higher precision; this comes at the cost
# of an extra fp8 cast.
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
ln_out, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps
)
ln_out = cast_to_fp8(
ln_out_return,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
ln_out_return = ln_out
# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
if fp8:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
else activation_dtype
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
out = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
ln_out_total,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
)
else:
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
out, _, _ = gemm(
weight,
ln_out_total,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
)
ctx.save_for_backward(
inputmat,
ln_weight,
mu,
rsigma,
weight,
weight_t_fp8,
ln_out,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.return_layernorm_output = return_layernorm_output
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row":
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
if return_layernorm_output:
return out, ln_out_return.view_as(inp)
return out
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta)
(
inputmat,
ln_weight,
mu,
rsigma,
weight,
weight_t_fp8,
ln_out,
fwd_scale_inverses,
) = ctx.saved_tensors
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], ctx.parallel_mode == "row"
)
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
)
else:
ln_out_total = ln_out
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# DGRAD
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column":
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
ln_out_total_c = cast_from_fp8(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
wgrad, _, _ = gemm(
ln_out_total_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
ln_out_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
# Column Parallel Linear
if ctx.parallel_mode == "column" and handle is not None:
handle.wait()
# LayerNorm gradient
d_ln_out = dgrad.view(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight
)
if not ctx.use_bias:
grad_bias = None
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
return (
dxmat.view(ctx.inp_shape),
dgamma,
dbeta,
wgrad,
None,
None,
grad_bias,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class LayerNormLinear(TransformerEngineBaseModule):
"""
Applies layer normalization followed by linear transformation to the incoming data.
Parameters
----------
in_features : int
size of each input sample.
out_features : int
size of each output sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
skip_weight_param_allocation: bool, default = `False`
if set to `True`, weight parameter is not allocated and must be
passed as a keyword argument `weight` during the forward pass.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
the weight gradient.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
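Examples
--------
A minimal usage sketch (assumes a CUDA device; the 16-divisible shapes satisfy the FP8
requirements noted above, though FP8 itself is optional):
.. code-block:: python
layer = LayerNormLinear(1024, 1024, bias=True)
inp = torch.randn(16, 1024, device="cuda")
out = layer(inp)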
"""
def __init__(
self,
in_features: int,
out_features: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
fuse_wgrad_accumulation: bool = False,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
params_dtype: torch.dtype = torch.float32,
parallel_mode: Optional[str] = None,
return_layernorm_output: bool = False,
skip_weight_param_allocation: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.return_layernorm_output = return_layernorm_output
self.skip_weight_param_allocation = skip_weight_param_allocation
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
if init_method is None:
init_method = get_default_init_method()
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
self.eps = eps
self.layer_norm_weight = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.layer_norm_bias = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
self.reset_layer_norm_parameters()
if not skip_weight_param_allocation:
self.weight = Parameter(
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
initialize_affine_weight_gpu(
self.weight,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias or self.return_bias:
self.bias = Parameter(
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
if self.parallel_mode == "column":
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
else:
self.register_buffer("bias", torch.Tensor(), persistent=False)
with torch.no_grad():
self.bias.zero_()
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.use_bias:
self.gemm_bias_unfused_add = True
self.use_bias = False
else:
self.gemm_bias_unfused_add = False
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
init.ones_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)
def forward(
self,
inp: torch.Tensor,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
Parameters
----------
inp : torch.Tensor
Input tensor.
weight : torch.Tensor, default = None
An optional weight tensor for the module. This argument is compulsory if module
is initialized with `skip_weight_param_allocation=True`
bias : torch.Tensor, default = None
An optional bias tensor for the module. This argument is compulsory if module
is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
or `return_bias` is set to `True`.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
self.pre_forward(inp)
bias_tensor = bias if bias is not None else self.bias
out = _LayerNormLinear.apply(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight if weight is not None else self.weight,
self.weight1_fp8,
self.weight1_t_fp8,
bias_tensor,
self.use_bias,
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
)
self.post_forward()
if self.return_layernorm_output:
out, ln_out = out
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
if self.return_bias:
if self.return_layernorm_output:
return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out
return out, cast_if_needed(bias_tensor, self.activation_dtype)
if self.return_layernorm_output:
return out, ln_out
return out
class _Linear(torch.autograd.Function):
"""Linear semi-top level module
Calls custom cuda extensions.
"""
@staticmethod
def forward(
ctx,
weight: torch.Tensor,
weight_fp8: torch.Tensor,
weight_t_fp8: torch.Tensor,
inp: torch.Tensor,
bias: torch.Tensor,
use_bias: bool,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
sequence_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_no_fp8 = inputmat
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not fp8_meta["recipe"].override_linear_precision.wgrad:
inputmat, inputmat_t = fp8_cast_transpose_fused(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat
if fp8:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
else activation_dtype
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
out = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
                inputmat_total,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
)
else:
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
out, _, _ = gemm(
weight,
inputmat_total,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
)
ctx.save_for_backward(
inputmat_no_fp8
if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad
else None,
inputmat_t
if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
else None,
weight,
weight_t_fp8,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row":
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta)
(
inputmat,
inputmat_t,
weight,
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_output, ctx.parallel_mode == "row"
)
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
inputmat_t_total, handle = gather_along_last_dim(
inputmat_t, ctx.tp_group, async_op=True
)
else:
inputmat_total, handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=True
)
else:
inputmat_t_total = inputmat_t
inputmat_total = inputmat
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# DGRAD
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# DGRAD
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column":
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
wgrad = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
wgrad, _, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
# Column Parallel Linear
if ctx.parallel_mode == "column" and handle is not None:
handle.wait()
if not ctx.use_bias:
grad_bias = None
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
return (
wgrad,
None,
None,
dgrad.view(ctx.inp_shape),
grad_bias,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class Linear(TransformerEngineBaseModule):
"""
Applies a linear transformation to the incoming data :math:`y = xA^T + b`
On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
Parameters
----------
in_features : int
size of each input sample.
out_features : int
size of each output sample.
bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
skip_weight_param_allocation: bool, default = `False`
if set to `True`, weight parameter is not allocated and must be
passed as a keyword argument `weight` during the forward pass.
Optimization parameters
-----------------------
    fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(
self,
in_features: int,
out_features: int,
sequence_parallel: bool = False,
fuse_wgrad_accumulation: bool = False,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
params_dtype: torch.dtype = torch.float32,
parallel_mode: Optional[str] = None,
skip_weight_param_allocation: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.skip_weight_param_allocation = skip_weight_param_allocation
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
if init_method is None:
init_method = get_default_init_method()
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
if not skip_weight_param_allocation:
self.weight = Parameter(
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
initialize_affine_weight_gpu(
self.weight,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias or self.return_bias:
self.bias = Parameter(
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
if self.parallel_mode == "column":
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
else:
self.register_buffer("bias", torch.Tensor(), persistent=False)
with torch.no_grad():
self.bias.zero_()
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.use_bias:
self.gemm_bias_unfused_add = True
self.use_bias = False
else:
self.gemm_bias_unfused_add = False
def forward(
self,
inp: torch.Tensor,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply the linear transformation to the input.
Parameters
----------
inp : torch.Tensor
Input tensor.
weight : torch.Tensor, default = None
            An optional weight tensor for the module. This argument is compulsory if the
            module is initialized with `skip_weight_param_allocation=True`.
        bias : torch.Tensor, default = None
            An optional bias tensor for the module. This argument is compulsory if the
            module is initialized with `skip_weight_param_allocation=True` and one of
            `use_bias` or `return_bias` is set to `True`.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
self.pre_forward(inp)
bias_tensor = bias if bias is not None else self.bias
out = _Linear.apply(
weight if weight is not None else self.weight,
self.weight1_fp8,
self.weight1_t_fp8,
inp,
bias_tensor,
self.use_bias,
is_first_microbatch,
self.fp8,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.activation_dtype,
self.parallel_mode,
)
self.post_forward()
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
if self.return_bias:
return out, cast_if_needed(bias_tensor, self.activation_dtype)
return out
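# Illustrative usage sketch (not part of the original source). The docstring above
# describes `Linear` as a drop-in replacement for `torch.nn.Linear` on NVIDIA GPUs;
# assuming the class is exposed as `transformer_engine.pytorch.Linear` (the import
# path used later in this code drop), a single-GPU layer could be used roughly as:
#
#     import torch
#     import transformer_engine.pytorch as te
#
#     layer = te.Linear(1024, 4096, bias=True)     # y = xA^T + b
#     x = torch.randn(32, 1024, device="cuda")
#     y = layer(x)                                 # [32, 4096]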
class _LayerNormMLP(torch.autograd.Function):
"""LayerNormMLP semi-top level module
Calls custom cuda extensions.
"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
fc1_weight: torch.Tensor,
fc1_weight_fp8: torch.Tensor,
fc1_weight_t_fp8: torch.Tensor,
fc1_bias: torch.Tensor,
fc2_weight: torch.Tensor,
fc2_weight_fp8: torch.Tensor,
fc2_weight_t_fp8: torch.Tensor,
fc2_bias: torch.Tensor,
use_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
sequence_parallel: bool,
activation_dtype: torch.dtype,
return_layernorm_output: bool,
bias_gelu_nvfusion: bool,
set_parallel_mode: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)
        # If the residual connection is taken after LN, we need the `ln_out`
        # tensor in higher precision; this comes at the cost of an extra
        # FP8 cast.
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
ln_out, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps
)
ln_out = cast_to_fp8(
ln_out_return,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
ln_out_return = ln_out
# Column Parallel Linear
if set_parallel_mode and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
if fp8:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
else activation_dtype
)
fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias
if update_fp8_weights:
fp8_cast_transpose_fused(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=fc1_weight_fp8,
transpose_out=fc1_weight_t_fp8,
)
fp8_cast_transpose_fused(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
cast_out=fc2_weight_fp8,
transpose_out=fc2_weight_t_fp8,
)
fc1_out = fp8_gemm(
fc1_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
ln_out_total,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
activation_dtype,
get_workspace(),
bias=fc1_bias,
use_bias=True,
use_split_accumulator=_2X_ACC_FPROP,
)
gelu_out = fp8_gelu(
fc1_out,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
)
fc2_out = fp8_gemm(
fc2_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT],
fp8_dtype_forward,
gelu_out,
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM2_INPUT],
fp8_dtype_forward,
activation_dtype,
get_workspace(),
bias=fc2_bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
)
else:
# Cast for native AMP
fc1_weight = cast_if_needed(fc1_weight, activation_dtype)
fc2_weight = cast_if_needed(fc2_weight, activation_dtype)
fc1_bias = cast_if_needed(fc1_bias, activation_dtype)
fc2_bias = (
cast_if_needed(fc2_bias, activation_dtype) if use_bias else fc2_bias
)
fc1_outputs = gemm(
fc1_weight,
ln_out_total,
activation_dtype,
get_workspace(),
bias=fc1_bias,
use_bias=not bias_gelu_nvfusion,
gelu=not bias_gelu_nvfusion,
)
if bias_gelu_nvfusion:
fc1_out, _, _ = fc1_outputs
gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
else:
gelu_out, _, fc1_out = fc1_outputs
fc2_out, _, _ = gemm(
fc2_weight,
gelu_out,
activation_dtype,
get_workspace(),
bias=fc2_bias,
use_bias=use_bias,
)
ctx.save_for_backward(
inputmat,
ln_weight,
mu,
rsigma,
ln_out,
fc1_out,
gelu_out,
fc1_weight,
fc1_weight_t_fp8,
fc2_weight,
fc2_weight_t_fp8,
fc1_bias,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.inp_shape = inp.shape
ctx.tp_group = tp_group
ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
ctx.return_layernorm_output = return_layernorm_output
ctx.set_parallel_mode = set_parallel_mode
# Row Parallel Linear
if set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
elif set_parallel_mode:
fc2_out, _ = allreduce(fc2_out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1])
if return_layernorm_output:
return fc2_out, ln_out_return.view_as(inp)
return fc2_out
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
TransformerEngineBaseModule.pre_backward(ctx.fp8, ctx.fp8_meta)
(
inputmat,
ln_weight,
mu,
rsigma,
ln_out,
fc1_out,
gelu_out,
fc1_weight,
fc1_weight_t_fp8,
fc2_weight,
fc2_weight_t_fp8,
fc1_bias,
fwd_scale_inverses,
) = ctx.saved_tensors
(
grad_output,
grad_output_c,
grad_output_t,
fc2_bias_grad,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], True
)
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.set_parallel_mode and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
)
else:
ln_out_total = ln_out
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# FC2 DGRAD
fc2_dgrad = fp8_gemm(
fc2_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_WEIGHT],
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
# FC2 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
fc2_wgrad = fp8_gemm(
gelu_out_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM2_INPUT],
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused(
fc2_dgrad,
fc1_out,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
else:
gelu_out_c = cast_from_fp8(
gelu_out,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc2_wgrad, _, _ = gemm(
gelu_out_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
fc2_dgrad, fc1_out, fc1_bias
)
dgelu = cast_to_fp8(
dgelu_no_fp8,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
dgelu_t = None
# FC1 DGRAD
fc1_dgrad = fp8_gemm(
fc1_weight_t_fp8,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_WEIGHT],
fp8_dtype_forward,
dgelu,
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT2],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
else:
# FC2 DGRAD
fc2_dgrad, _, _ = gemm(
fc2_weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
gelu=not ctx.bias_gelu_nvfusion,
grad=True,
gelu_input=fc1_out,
)
# FC2 WGRAD
fc2_wgrad, fc2_bias_grad, _ = gemm(
gelu_out,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if ctx.bias_gelu_nvfusion:
fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
else:
dgelu = fc2_dgrad
# FC1 DGRAD
fc1_dgrad, _, _ = gemm(
fc1_weight,
dgelu,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.set_parallel_mode and ctx.sequence_parallel:
handle.wait()
fc1_dgrad, handle = reduce_scatter_along_first_dim(
fc1_dgrad, ctx.tp_group, async_op=True
)
elif ctx.set_parallel_mode:
fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
if ctx.fp8:
# FC1 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses[tex.FP8FwdTensors.GEMM1_INPUT],
fp8_dtype_forward,
dgelu_t,
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT2
],
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
ln_out_total_c = cast_from_fp8(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc1_wgrad, _, _ = gemm(
ln_out_total_c,
dgelu_no_fp8,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
# FC1 WGRAD
fc1_wgrad_outputs = gemm(
ln_out_total,
dgelu,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=not ctx.bias_gelu_nvfusion,
accumulate=accumulate_wgrad_into_param_main_grad,
fp32_output=ctx.fuse_wgrad_accumulation,
out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if ctx.bias_gelu_nvfusion:
fc1_wgrad, _, _ = fc1_wgrad_outputs
else:
fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
# Column Parallel Linear
if ctx.set_parallel_mode and handle is not None:
handle.wait()
# LayerNorm gradient
d_ln_out = fc1_dgrad.view(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight
)
if not ctx.use_bias:
fc2_bias_grad = None
TransformerEngineBaseModule.post_backward(
ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group
)
return (
dxmat.view(ctx.inp_shape),
dgamma,
dbeta,
fc1_wgrad,
None,
None,
fc1_bias_grad,
fc2_wgrad,
None,
None,
fc2_bias_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class LayerNormMLP(TransformerEngineBaseModule):
"""
Applies layer normalization on the input followed by the MLP module, consisting of
2 successive linear transformations, separated by the GeLU activation.
Parameters
----------
hidden_size : int
size of each input sample.
ffn_hidden_size : int
intermediate size to which input samples are projected.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True`
if set to `False`, the FC2 layer will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
output_layer_init_method : Callable, default = `None`
used for initializing FC2 weights in the following way:
`output_layer_init_method(weight)`. When set to `None`, defaults to
`torch.nn.init.normal_(mean=0.0, std=0.023)`.
return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation.
Example use case: residual connection for transformer module
is taken post layernorm.
Parallelism parameters
----------------------
set_parallel_mode : bool, default = `False`
if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
Optimization parameters
-----------------------
    fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
the weight gradient.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
    seq_length: int
        sequence length of input samples. Needed for JIT Warmup, a technique where
        jit fused functions are warmed up before training to ensure the same kernels
        are used for forward propagation and the activation recompute phase.
    micro_batch_size: int
        batch size per training step. Needed for JIT Warmup, a technique where jit
        fused functions are warmed up before training to ensure the same kernels are
        used for forward propagation and the activation recompute phase.
"""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
return_bias: bool = False,
get_rng_state_tracker: Optional[Callable] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
init_method: Optional[Callable] = None,
bias: bool = True,
output_layer_init_method: Optional[Callable] = None,
fuse_wgrad_accumulation: bool = False,
params_dtype: torch.dtype = torch.float32,
return_layernorm_output: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
set_parallel_mode: bool = False,
) -> None:
super().__init__()
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.return_layernorm_output = return_layernorm_output
self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
self.set_parallel_mode = set_parallel_mode
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
if init_method is None:
init_method = get_default_init_method()
if output_layer_init_method is None:
output_layer_init_method = get_default_init_method()
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
self.size_per_partition = divide(ffn_hidden_size, self.tp_size)
# LN init
self.eps = eps
self.layer_norm_weight = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.layer_norm_bias = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
self.reset_layer_norm_parameters()
# FC1 init
self.fc1_weight = Parameter(
torch.empty(
self.size_per_partition,
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.fp8_weight_shapes.append(self.fc1_weight.shape)
initialize_affine_weight_gpu(
self.fc1_weight,
init_method,
get_rng_state_tracker,
partition_dim=0,
stride=1,
)
self.fc1_bias = Parameter(
torch.empty(
self.size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
with torch.no_grad():
self.fc1_bias.zero_()
# FC2 init
self.fc2_weight = Parameter(
torch.empty(
hidden_size,
self.size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.fp8_weight_shapes.append(self.fc2_weight.shape)
initialize_affine_weight_gpu(
self.fc2_weight,
output_layer_init_method,
get_rng_state_tracker,
partition_dim=1,
stride=1,
)
if self.use_bias or self.return_bias:
self.fc2_bias = Parameter(
torch.empty(
hidden_size, device=torch.cuda.current_device(), dtype=params_dtype
)
)
else:
self.register_buffer("fc2_bias", torch.Tensor(), persistent=False)
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.set_parallel_mode and self.use_bias:
self.gemm_bias_unfused_add = True
self.use_bias = False
else:
self.gemm_bias_unfused_add = False
with torch.no_grad():
self.fc2_bias.zero_()
if self.bias_gelu_nvfusion:
set_jit_fusion_options()
if seq_length and micro_batch_size:
warmup_jit_bias_gelu_all_dtypes(
self.size_per_partition, seq_length, micro_batch_size
)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
init.ones_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)
def forward(
self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a feedforward network (MLP Block).
Parameters
----------
inp : torch.Tensor
Input tensor.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
self.pre_forward(inp, num_gemms=2)
out = _LayerNormMLP.apply(
inp,
self.layer_norm_weight,
self.layer_norm_bias,
self.fc1_weight,
self.weight1_fp8,
self.weight1_t_fp8,
self.fc1_bias,
self.fc2_weight,
self.weight2_fp8,
self.weight2_t_fp8,
self.fc2_bias,
            self.use_bias,  # already set to False in __init__ for RPL (unfused bias add)
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.activation_dtype,
self.return_layernorm_output,
self.bias_gelu_nvfusion,
self.set_parallel_mode,
)
self.post_forward()
if self.return_layernorm_output:
out, ln_out = out
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(self.fc2_bias, self.activation_dtype)
if self.return_bias:
if self.return_layernorm_output:
return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out
return out, cast_if_needed(self.fc2_bias, self.activation_dtype)
if self.return_layernorm_output:
return out, ln_out
return out
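# Illustrative usage sketch (not part of the original source). Based on the
# constructor above, a non-tensor-parallel LayerNormMLP block
# (LN -> FC1 -> GeLU -> FC2) could be used roughly as follows; sizes are arbitrary:
#
#     import torch
#     import transformer_engine.pytorch as te
#
#     mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, eps=1e-5)
#     x = torch.randn(128, 8, 1024, device="cuda")   # [seq, batch, hidden]
#     y = mlp(x)                                     # same shape as x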
class _LayerNorm(torch.autograd.Function):
"""functional LayerNorm"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
eps: float,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features))
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
return ln_out.view_as(inp)
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight
)
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None
class LayerNorm(torch.nn.Module):
r"""
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size :attr:`hidden_size`
Parameters
----------
hidden_size : int
size of each input sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
params_dtype: torch.dtype = torch.float32,
) -> None:
super().__init__()
self.eps = eps
self.layer_norm_weight = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.layer_norm_bias = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.layer_norm_weight, "sequence_parallel", sequence_parallel)
setattr(self.layer_norm_bias, "sequence_parallel", sequence_parallel)
self.reset_layer_norm_parameters()
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
init.ones_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
return _LayerNorm.apply(
inp, self.layer_norm_weight, self.layer_norm_bias, self.eps
)
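# Illustrative usage sketch (not part of the original source): `LayerNorm` above is a
# standalone normalization module with learnable gamma/beta of size `hidden_size`:
#
#     import torch
#     import transformer_engine.pytorch as te
#
#     ln = te.LayerNorm(hidden_size=1024, eps=1e-5)
#     x = torch.randn(128, 8, 1024, device="cuda")
#     y = ln(x)                                      # normalized over the last dim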
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utilities for debugging numerical issues with FP8"""
from typing import Tuple
import torch
from transformer_engine.common import recipe
_NUMERICS_DEBUG = False
def debug(enabled: bool = True) -> None:
"""Set FP8 debug mode"""
global _NUMERICS_DEBUG
_NUMERICS_DEBUG = enabled
def fp8_tensor_statistics(
tensor: torch.Tensor, fp8_format: str = "E4M3"
) -> Tuple[int, ...]:
"""Print FP8 tensor stats"""
fp8_format = fp8_format.upper()
assert fp8_format in (
"E4M3",
"E5M2",
), "fp8_format must be 'E4M3' or 'E5M2' for amax"
fmt = recipe.Format[fp8_format]
FP8_MAX = fmt.value.max_fwd
num_overflows = (tensor == FP8_MAX).sum().item()
num_underflows = (tensor == 0).sum().item()
return (num_underflows, num_overflows)
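# Illustrative usage sketch (not part of the original source). `fp8_tensor_statistics`
# counts zero entries (underflows) and entries equal to the format maximum (overflows)
# of an already-quantized FP8 tensor; `some_fp8_tensor` is a hypothetical placeholder:
#
#     underflows, overflows = fp8_tensor_statistics(some_fp8_tensor, fp8_format="E4M3")
#     print(f"underflows={underflows}, overflows={overflows}")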
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused scaled masked softmax functions"""
import os
from typing import Callable, Tuple, Union
import torch
from torch import nn
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
    Fused operation which performs the following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
"""ScaledUpperTriangMaskedSoftmax fwd"""
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledUpperTriangMaskedSoftmax bwd"""
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
    Fused operation which performs the following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(
ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float
) -> torch.Tensor:
"""ScaledMaskedSoftmax fwd"""
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledMaskedSoftmax bwd"""
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
class ScaledSoftmax(torch.autograd.Function):
"""
    Fused operation which performs the following two operations in sequence
1. Scale the tensor.
2. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
"""ScaledSoftmax fwd"""
import scaled_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledSoftmax bwd"""
import scaled_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
attn_mask_type: attention mask type (pad or causal)
mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed in fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
attn_mask_type: str,
mask_func: Callable,
softmax_in_fp32: bool,
scale: float,
) -> None:
super().__init__()
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = bool(
int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
)
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, inp: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""FusedScaleMaskSoftmax fprop"""
# [b, np, sq, sk]
assert inp.dim() == 4
self.input_in_fp16 = inp.dtype == torch.float16
self.input_in_bf16 = inp.dtype == torch.bfloat16
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
if self.is_kernel_available(*inp.size()):
return self.forward_fused_softmax(inp, mask)
return self.forward_torch_softmax(inp, mask)
def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
"""Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * np
if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16 or bf16
            and 16 < sk <= 4096  # sk must be in the range (16, 4096]
            and sq % 4 == 0  # sq must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(
self, inp: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Fused masked softmax kernel"""
b, np, sq, sk = inp.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == "causal":
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
inp = inp.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale)
return probs.view(b, np, sq, sk)
# input is 4D tensor (b, np, sq, sk)
if mask is not None:
return ScaledMaskedSoftmax.apply(inp, mask, scale)
return ScaledSoftmax.apply(inp, scale)
def forward_torch_softmax(
self, inp: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
"""Framework softmax"""
if self.input_in_float16 and self.softmax_in_fp32:
inp = inp.float()
if self.scale is not None:
inp = inp * self.scale
mask_output = self.mask_func(inp, mask) if mask is not None else inp
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq: int, sk: int, b: int, np: int) -> int:
"""Softmax utility"""
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
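# Illustrative usage sketch (not part of the original source). FusedScaleMaskSoftmax
# takes [b, np, sq, sk] attention scores and falls back to forward_torch_softmax when
# the fused kernel is unavailable; `attention_mask_func` stands in for the caller's
# masking function (the transformer code below imports one from utils):
#
#     softmax = FusedScaleMaskSoftmax(
#         attn_mask_type="causal",
#         mask_func=attention_mask_func,
#         softmax_in_fp32=True,
#         scale=None,
#     )
#     probs = softmax(scores, mask)    # scores: [b, np, sq, sk]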
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer."""
import os
import math
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union
import torch
from torch.nn.parameter import Parameter
from transformer_engine.pytorch import LayerNormLinear, Linear, LayerNormMLP, LayerNorm
from transformer_engine.pytorch.jit import (
set_jit_fusion_options,
warmup_jit_bias_dropout_add_all_dtypes,
get_bias_dropout_add,
bias_dropout_add_fused_train,
bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
divide,
attention_mask_func,
split_tensor_along_last_dim,
cast_if_needed,
get_default_init_method,
)
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
LayerTypes,
dist_group_type,
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
get_distributed_world_size,
checkpoint,
initialize_affine_weight_gpu,
set_tensor_model_parallel_attributes,
)
class DropPath(torch.nn.Module):
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
"""DropPath FWD"""
if self.drop_prob == 0.0 or not self.training:
return hidden_state
keep_prob = 1 - self.drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=hidden_state.dtype, device=hidden_state.device
)
random_tensor.floor_() # binarize
output = hidden_state.div(keep_prob) * random_tensor
return output
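# A minimal numeric sketch of the stochastic-depth behavior implemented above (not
# part of the original source): in training mode each sample's residual branch is
# zeroed with probability `drop_prob`, and survivors are rescaled by 1 / keep_prob so
# the expected value of the output matches the input:
#
#     drop = DropPath(drop_prob=0.1).train()
#     h = torch.ones(4, 16, device="cuda")
#     out = drop(h)    # each row is all zeros or all 1 / 0.9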
class CoreAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
"""
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
layer_number: Optional[int] = None,
apply_query_key_layer_scaling: bool = True,
attention_softmax_in_fp32: bool = False,
attn_mask_type: str = "causal",
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
sequence_parallel: bool = False,
) -> None:
super().__init__()
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
if layer_number is None:
self.apply_query_key_layer_scaling = False
else:
self.layer_number = max(1, layer_number)
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.attn_mask_type = attn_mask_type
projection_size = kv_channels * num_attention_heads
assert (
attn_mask_type in AttnMaskTypes
), f"attn_mask_type {attn_mask_type} not supported"
# Per attention head and per partition values.
self.hidden_size_per_partition = divide(projection_size, tp_size)
self.hidden_size_per_attention_head = divide(
projection_size, num_attention_heads
)
self.sequence_parallel = sequence_parallel
if self.sequence_parallel or get_rng_state_tracker is None:
self.attention_dropout_ctx = nullcontext
else:
self.attention_dropout_ctx = get_rng_state_tracker().fork
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.attn_mask_type,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff,
)
        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different numbers of parallel partitions, but
        # on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout)
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""core attention fprop"""
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(
output_size[2], output_size[0] * output_size[1], -1
)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
        # preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
attention_probs = self.attention_dropout(attention_probs)
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
output_size = (
value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3),
)
# change view [sk, b * np, hn]
value_layer = value_layer.view(
value_layer.size(0), output_size[0] * output_size[1], -1
)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(
output_size[0] * output_size[1], output_size[2], -1
)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_partition,
)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class MultiHeadAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
layernorm_epsilon: float,
init_method: Callable,
output_layer_init_method: Callable,
layer_number: Optional[int] = None,
apply_query_key_layer_scaling: bool = True,
attention_softmax_in_fp32: bool = False,
attn_mask_type: str = "causal",
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
fuse_wgrad_accumulation: bool = False,
get_rng_state_tracker: Optional[Callable] = None,
sequence_parallel: bool = False,
params_dtype: torch.dtype = torch.float32,
return_layernorm_output: bool = False,
input_layernorm: bool = False,
attention_type: str = "self",
set_parallel_mode: bool = False,
fuse_qkv_params: bool = False,
) -> None:
super().__init__()
        self.layer_number = layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
self.get_rng_state_tracker = get_rng_state_tracker
self.tp_group = tp_group
self.return_layernorm_output = return_layernorm_output
self.params_dtype = params_dtype
self.init_method = init_method
self.fuse_qkv_params = fuse_qkv_params
assert (
attention_type in AttnTypes
), f"attention_type {attention_type} not supported"
tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_size = tp_size
self.sequence_parallel = (tp_size > 1) and sequence_parallel
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = divide(
projection_size, num_attention_heads
)
self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"tp_group": tp_group,
"tp_size": tp_size,
"get_rng_state_tracker": get_rng_state_tracker,
"sequence_parallel": sequence_parallel,
"params_dtype": params_dtype,
}
qkv_parallel_mode = "column" if set_parallel_mode else None
if not fuse_qkv_params:
self.set_qkv_params(
hidden_size,
3 * hidden_size,
parallel_mode=qkv_parallel_mode,
bias=True,
)
if self.attention_type == "self":
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
hidden_size,
3 * hidden_size,
eps=layernorm_epsilon,
init_method=init_method,
bias=True,
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
skip_weight_param_allocation=not fuse_qkv_params,
**common_gemm_kwargs,
)
else:
self.qkv = Linear(
hidden_size,
3 * hidden_size,
init_method=init_method,
bias=True,
return_bias=False,
parallel_mode=qkv_parallel_mode,
skip_weight_param_allocation=not fuse_qkv_params,
**common_gemm_kwargs,
)
else:
if self.input_layernorm:
self.layernorm_query = LayerNormLinear(
hidden_size,
hidden_size,
eps=layernorm_epsilon,
init_method=init_method,
bias=True,
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
skip_weight_param_allocation=not fuse_qkv_params,
**common_gemm_kwargs,
)
else:
self.query = Linear(
hidden_size,
hidden_size,
init_method=init_method,
bias=True,
return_bias=False,
parallel_mode=qkv_parallel_mode,
skip_weight_param_allocation=not fuse_qkv_params,
**common_gemm_kwargs,
)
self.key_value = Linear(
hidden_size,
2 * hidden_size,
init_method=init_method,
bias=True,
return_bias=False,
parallel_mode=qkv_parallel_mode,
skip_weight_param_allocation=not fuse_qkv_params,
**common_gemm_kwargs,
)
# Core Self attention.
self.core_attention = CoreAttention(
num_attention_heads,
kv_channels,
attention_dropout,
layer_number=layer_number,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_softmax_in_fp32=attention_softmax_in_fp32,
tp_size=tp_size,
get_rng_state_tracker=get_rng_state_tracker,
attn_mask_type=attn_mask_type,
sequence_parallel=sequence_parallel,
)
# Linear
self.proj = Linear(
hidden_size,
hidden_size,
init_method=output_layer_init_method,
bias=False,
return_bias=True,
parallel_mode="row" if set_parallel_mode else None,
**common_gemm_kwargs,
)
def set_qkv_params(
self,
        in_features: int,
        out_features: int,
        parallel_mode: Optional[str] = None,
bias: bool = False,
) -> None:
"""Initialize separate Parameters for query, key, and value tensors."""
if parallel_mode == "column":
out_features = divide(out_features, self.tp_size)
elif parallel_mode == "row":
in_features = divide(in_features, self.tp_size)
assert (
out_features % 3 == 0
), f"3 way QKV split with dimension {out_features} not possible."
weight_tensor = torch.empty(
out_features,
in_features,
device=torch.cuda.current_device(),
dtype=self.params_dtype,
)
initialize_affine_weight_gpu(
weight_tensor,
self.init_method,
self.get_rng_state_tracker,
partition_dim=1 if parallel_mode == "row" else 0,
stride=1,
)
qkv_first_dim = out_features // 3
self.query = Parameter(weight_tensor[0:qkv_first_dim, :])
self.key = Parameter(weight_tensor[qkv_first_dim : 2 * qkv_first_dim, :])
self.value = Parameter(weight_tensor[2 * qkv_first_dim : 3 * qkv_first_dim, :])
set_tensor_model_parallel_attributes(
tensor=self.query,
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
set_tensor_model_parallel_attributes(
tensor=self.key,
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
set_tensor_model_parallel_attributes(
tensor=self.value,
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if bias:
bias_tensor = torch.empty(
out_features,
device=torch.cuda.current_device(),
dtype=self.params_dtype,
)
self.query_bias = Parameter(bias_tensor[0:qkv_first_dim])
self.key_bias = Parameter(bias_tensor[qkv_first_dim : 2 * qkv_first_dim])
self.value_bias = Parameter(
bias_tensor[2 * qkv_first_dim : 3 * qkv_first_dim]
)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(self.query_bias, True, 0, 1)
set_tensor_model_parallel_attributes(self.key_bias, True, 0, 1)
set_tensor_model_parallel_attributes(self.value_bias, True, 0, 1)
else:
self.register_buffer("query_bias", torch.Tensor(), persistent=False)
self.register_buffer("key_bias", torch.Tensor(), persistent=False)
self.register_buffer("value_bias", torch.Tensor(), persistent=False)
with torch.no_grad():
self.query_bias.zero_()
self.key_bias.zero_()
self.value_bias.zero_()
def _checkpointed_core_attention_forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Forward method with activation checkpointing."""
def custom_forward(*inputs):
query_layer = inputs[0]
key_layer = inputs[1]
value_layer = inputs[2]
attention_mask = inputs[3]
output_ = self.core_attention(
query_layer, key_layer, value_layer, attention_mask
)
return output_
hidden_states = checkpoint(
custom_forward,
False,
self.get_rng_state_tracker,
self.tp_group,
query_layer,
key_layer,
value_layer,
attention_mask,
)
return hidden_states
def _allocate_memory(
self, inference_max_sequence_len: int, batch_size: int
) -> torch.Tensor:
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device(),
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group"""
self.tp_group = tp_group
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_output: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: Optional[bool] = None,
inference_params: Optional[Any] = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""MultiHeadAttention FWD"""
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params and self.layer_number is not None:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
inference_value_memory,
)
else:
(
inference_key_memory,
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == "self":
qkv_weight = (
torch.cat((self.query, self.key, self.value))
if not self.fuse_qkv_params
else None
)
qkv_bias = (
torch.cat((self.query_bias, self.key_bias, self.value_bias))
if not self.fuse_qkv_params
else None
)
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
weight=qkv_weight,
bias=qkv_bias,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs
else:
mixed_x_layer = layernorm_qkv_outputs
else:
mixed_x_layer = self.qkv(
hidden_states,
weight=qkv_weight,
bias=qkv_bias,
is_first_microbatch=is_first_microbatch,
)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
query_layer, key_layer, value_layer = split_tensor_along_last_dim(
mixed_x_layer, 3
)
else:
kv_weight = (
torch.cat((self.key, self.value)) if not self.fuse_qkv_params else None
)
kv_bias = (
torch.cat((self.key_bias, self.value_bias))
if not self.fuse_qkv_params
else None
)
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer = self.key_value(
encoder_output,
weight=kv_weight,
bias=kv_bias,
is_first_microbatch=is_first_microbatch,
)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer, value_layer) = split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(
hidden_states,
weight=self.query,
bias=self.query_bias,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
else:
query_layer = layernorm_query_outputs
else:
query_layer = self.query(
hidden_states,
weight=self.query,
bias=self.query_bias,
is_first_microbatch=is_first_microbatch,
)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if inference_params and self.layer_number is not None:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...
]
# ==================================
# core attention computation
# ==================================
if checkpoint_core_attention:
context_layer = self._checkpointed_core_attention_forward(
query_layer, key_layer, value_layer, attention_mask
)
else:
context_layer = self.core_attention(
query_layer, key_layer, value_layer, attention_mask
)
# =================
# Output. [sq, b, h]
# =================
attention_output, attention_bias = self.proj(
context_layer, is_first_microbatch=is_first_microbatch
)
if self.input_layernorm and self.return_layernorm_output:
return attention_output, attention_bias, layernorm_output
return attention_output, attention_bias
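# Illustrative shape walk-through for the self-attention path above (the concrete
# numbers are example assumptions, not defaults): with hidden size h = 1024,
# np = 16 attention heads per partition and hn = 64 channels per head, an input
# of shape [sq=128, b=2, 1024] becomes [128, 2, 3072] after the packed QKV
# projection, is viewed as [128, 2, 16, 192], and split_tensor_along_last_dim(..., 3)
# then yields query, key and value layers of shape [128, 2, 16, 64] each.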
class TransformerLayer(torch.nn.Module):
"""
TransformerLayer is made up of an attention block and a feedforward network (MLP).
This standard layer is based on the paper "Attention Is All You Need".
Parameters
----------
hidden_size : int
size of each input sample.
ffn_hidden_size : int
intermediate size to which input samples are projected.
num_attention_heads : int
number of attention heads in the transformer layer.
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization
for numerical stability.
hidden_dropout: float, default = 0.1
dropout probability for the dropout op after FC2 layer.
attention_dropout: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
init_method : Callable, default = `None`
used for initializing the QKV and FC1 weights in the following way:
`init_method(weight)`. When set to `None`, defaults to
`torch.nn.init.normal_(mean=0.0, std=0.023)`.
output_layer_init_method : Callable, default = `None`
used for initializing weights of PROJ and FC2 in the following way:
`output_layer_init_method(weight)`. When set to `None`, defaults to
`torch.nn.init.normal_(mean=0.0, std=0.023)`.
apply_residual_connection_post_layernorm : bool, default = `False`
if set to `True`, residual connections are taken
from the output of layer norm (default is taken
from input of layer norm)
layer_number: int, default = `None`
layer number of the current `TransformerLayer` when multiple such modules are
concatenated to form a transformer block.
apply_query_key_layer_scaling: bool, default = `True`
apply query-key layer scaling during BMM1
by a factor of `layer_number`
output_layernorm: bool, default = `False`
if set to `True`, layer normalization is applied on the output side,
after the final dropout-add. default behavior is to apply layer
normalization on the input side, before the QKV transformation.
attention_softmax_in_fp32: bool, default = `False`
if set to `True`, softmax is executed in
torch.float32 dtype (single precision)
layer_type: {'encoder', 'decoder'}, default = `encoder`
if set to `decoder`, an additional cross-attn block is added after self-attn.
This can be used for structures like `T5` Transformer in conjunction with the
`encoder` option.
kv_channels: int, default = `None`
number of key-value channels. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
Parallelism parameters
----------------------
set_parallel_mode : bool, default = `False`
if set to `True`, QKV and FC1 layers are used as Column Parallel
whereas PROJ and FC2 are used as Row Parallel as described
`here <https://arxiv.org/pdf/1909.08053.pdf>`_.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
the weight gradient.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
seq_length: int
sequence length of input samples. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are used for
forward propagation and activation recompute phase.
micro_batch_size: int
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propagation and activation recompute phase.
drop_path_rate: float, default = 0.0
when > 0.0, applies stochastic depth per sample in
the main path of the residual block.
fuse_qkv_params: bool, default = `False`
if set to `True`, `TransformerLayer` module exposes a single fused
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatenations/splits and also enables the argument
`fuse_wgrad_accumulation`.
"""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_epsilon: float = 1e-5,
hidden_dropout: float = 0.1,
attention_dropout: float = 0.1,
init_method: Optional[Callable] = None,
output_layer_init_method: Optional[Callable] = None,
layer_number: Optional[int] = None,
kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal",
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
params_dtype: torch.dtype = torch.float32,
get_rng_state_tracker: Optional[Callable] = None,
fuse_wgrad_accumulation: bool = False,
apply_query_key_layer_scaling: bool = True,
attention_softmax_in_fp32: bool = False,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
sequence_parallel: bool = False,
apply_residual_connection_post_layernorm: bool = False,
output_layernorm: bool = False,
layer_type: str = "encoder",
drop_path_rate: float = 0.0,
set_parallel_mode: bool = False,
fuse_qkv_params: bool = False,
) -> None:
super().__init__()
bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
self.layer_number = layer_number
self.output_layernorm = output_layernorm
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = (
apply_residual_connection_post_layernorm
)
assert (
self_attn_mask_type in AttnMaskTypes
), f"self_attn_mask_type {self_attn_mask_type} not supported"
assert layer_type in LayerTypes, f"layer_type {layer_type} not supported"
if not fuse_qkv_params:
assert (
not fuse_wgrad_accumulation
), "Gradient accumulation fusion requires single QKV parameter."
self.kv_channels = (
kv_channels if kv_channels else (hidden_size // num_attention_heads)
)
if init_method is None:
init_method = get_default_init_method()
if output_layer_init_method is None:
output_layer_init_method = get_default_init_method()
tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.sequence_parallel = (tp_size > 1) and sequence_parallel
self.get_rng_state_tracker = get_rng_state_tracker
attention_args = (
hidden_size,
num_attention_heads,
self.kv_channels,
attention_dropout,
layernorm_epsilon,
init_method,
output_layer_init_method,
)
common_attention_kwargs = {
"layer_number": layer_number,
"apply_query_key_layer_scaling": apply_query_key_layer_scaling,
"attention_softmax_in_fp32": attention_softmax_in_fp32,
"tp_group": tp_group,
"tp_size": tp_size,
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"get_rng_state_tracker": get_rng_state_tracker,
"sequence_parallel": self.sequence_parallel,
"params_dtype": params_dtype,
"return_layernorm_output": apply_residual_connection_post_layernorm,
"set_parallel_mode": set_parallel_mode,
"fuse_qkv_params": fuse_qkv_params,
}
self.self_attention = MultiHeadAttention(
*attention_args,
**common_attention_kwargs,
attn_mask_type=self_attn_mask_type,
input_layernorm=not output_layernorm,
attention_type="self",
)
if layer_type == "decoder":
self.inter_attention = MultiHeadAttention(
*attention_args,
**common_attention_kwargs,
attn_mask_type="padding",
input_layernorm=True,
attention_type="cross",
)
# LayerNorm -> gelu(Linear + Bias) -> Linear
# parallel_mode not supported for LayerNormMLP,
# FC1 is CPL and FC2 is RPL
self.layernorm_mlp = LayerNormMLP(
hidden_size,
ffn_hidden_size,
eps=layernorm_epsilon,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=get_rng_state_tracker,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
bias=False,
return_bias=True,
sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype,
return_layernorm_output=apply_residual_connection_post_layernorm,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
set_parallel_mode=set_parallel_mode,
)
self.hidden_dropout = hidden_dropout
self.bias_dropout_fusion = bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
self.bias_dropout_add_exec_handler = (
nullcontext if use_nvfuser else torch.enable_grad
)
if self.bias_dropout_fusion:
set_jit_fusion_options()
if seq_length and micro_batch_size:
if self.sequence_parallel:
seq_length = seq_length // tp_size
warmup_jit_bias_dropout_add_all_dtypes(
hidden_size, seq_length, micro_batch_size
)
if self.output_layernorm:
self.layernorm = LayerNorm(
hidden_size,
eps=layernorm_epsilon,
sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype,
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
continue
if hasattr(child, "set_tensor_parallel_group"):
child.set_tensor_parallel_group(tp_group)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: Optional[bool] = False,
inference_params: Optional[Any] = None,
) -> torch.Tensor:
"""
Transformer Layer: attention block and a feedforward network (MLP)
Parameters
----------
hidden_states : torch.Tensor
Input tensor.
attention_mask : torch.Tensor
Boolean tensor used to mask out self-attention softmax input.
encoder_output : torch.Tensor
Output of the encoder block to be fed into the decoder block if using
`layer_type="decoder"`.
enc_dec_attn_mask : torch.Tensor
Boolean tensor used to mask out inter-attention softmax input if using
`layer_type="decoder"`.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism, a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
checkpoint_core_attention: bool, default = `False`
If true, forward activations for core attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
"""
# For AMP
if torch.is_autocast_enabled():
hidden_states = cast_if_needed(
hidden_states, torch.get_autocast_gpu_dtype()
)
# Self attention.
self_attention_outputs = self.self_attention(
hidden_states,
attention_mask,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
)
if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
attention_output, attention_bias, residual = self_attention_outputs
else:
attention_output, attention_bias = self_attention_outputs
residual = hidden_states
# Set BDA func.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# Bias dropout add.
if self.drop_path is None:
with self.bias_dropout_add_exec_handler():
bda_output = bias_dropout_add_func(
attention_output, attention_bias, residual, self.hidden_dropout
)
else:
out = torch.nn.functional.dropout(
attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training,
)
bda_output = residual + self.drop_path(out)
# Cross attention.
if self.layer_type == "decoder":
inter_attention_outputs = self.inter_attention(
bda_output,
enc_dec_attn_mask,
encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
)
if self.apply_residual_connection_post_layernorm:
attention_output, attention_bias, residual = inter_attention_outputs
else:
attention_output, attention_bias = inter_attention_outputs
residual = bda_output
with self.bias_dropout_add_exec_handler():
bda_output = bias_dropout_add_func(
attention_output, attention_bias, residual, self.hidden_dropout
)
# MLP.
mlp_outputs = self.layernorm_mlp(
bda_output, is_first_microbatch=is_first_microbatch
)
if self.apply_residual_connection_post_layernorm:
mlp_output, mlp_bias, residual = mlp_outputs
else:
mlp_output, mlp_bias = mlp_outputs
residual = bda_output
# Bias dropout add.
if self.drop_path is None:
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output, mlp_bias, residual, self.hidden_dropout
)
else:
out = torch.nn.functional.dropout(
mlp_output + mlp_bias, p=self.hidden_dropout, training=self.training
)
output = residual + self.drop_path(out)
# For BERT like architectures.
if self.output_layernorm:
output = self.layernorm(output)
# output: [sq, b, h]
return output
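# A minimal usage sketch (for illustration only, not part of the module): build a
# single encoder layer and run one forward pass. The import path, mask shape and
# hyperparameters below are assumptions chosen for the example; per the code above,
# parameters are allocated on the current CUDA device at construction.
#
#     import torch
#     from transformer_engine.pytorch import TransformerLayer
#
#     layer = TransformerLayer(
#         hidden_size=1024,
#         ffn_hidden_size=4096,
#         num_attention_heads=16,
#         layer_number=1,
#     )
#     # hidden_states: [sq, b, h]; mask entries set to True are excluded from softmax.
#     hidden_states = torch.randn(128, 2, 1024, device="cuda")
#     attention_mask = torch.zeros(2, 1, 128, 128, dtype=torch.bool, device="cuda")
#     output = layer(hidden_states, attention_mask)  # [sq, b, h]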
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for Transformer Engine modules"""
import math
from typing import Any, Callable, Optional, Tuple
import torch
def attention_mask_func(
attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""Get attention mask"""
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def get_default_init_method() -> Callable:
"""Weight initialization method if not provided by user"""
return init_method_normal(0.023)
def init_method_normal(sigma: float) -> Callable:
"""Init method based on N(0, sigma)."""
def init_(tensor: torch.Tensor) -> Callable:
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma: float, num_layers: int) -> Callable:
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor: torch.Tensor) -> Callable:
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
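# Worked example (the depth is an assumed value): with the library default sigma
# of 0.023 and num_layers = 24, the scaled standard deviation is
# 0.023 / sqrt(2 * 24) ~= 0.0033, so initialization tightens as depth grows.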
def all_close(a: torch.Tensor, b: torch.Tensor) -> bool:
"""torch.allclose with cpu to not run into OOMs"""
return torch.allclose(a.cpu(), b.cpu())
def print_rank_0(*args: Any) -> None:
"""print on rank 0"""
if torch.cuda.current_device() == 0:
print(*args)
def compare_tensors(a: torch.Tensor, b: torch.Tensor) -> None:
"""util function to show some tensor stats"""
if a.shape != b.shape:
print_rank_0("Tensors have different shape")
return
print_rank_0(a)
print_rank_0(b)
max_err = torch.max(torch.abs(a - b))
max_a = torch.max(a)
max_b = torch.max(b)
print_rank_0(f"max err={max_err}, max a={max_a}, max_b={max_b}")
def ensure_divisibility(numerator: int, denominator: int) -> None:
"""Ensure that numerator is divisible by the denominator."""
assert (
numerator % denominator == 0
), f"{numerator} is not divisible by {denominator}"
def divide(numerator: int, denominator: int) -> int:
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False
) -> Tuple[torch.Tensor, ...]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
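# Example, using the shapes produced by MultiHeadAttention above: a packed QKV
# tensor of shape [sq, b, np, 3 * hn] splits into three [sq, b, np, hn] views
# (pass contiguous_split_chunks=True if contiguous chunks are needed).
#
#     >>> t = torch.randn(128, 2, 16, 192)
#     >>> q, k, v = split_tensor_along_last_dim(t, 3)
#     >>> q.shape
#     torch.Size([128, 2, 16, 64])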
def validate_ctx_manager(ctx: Callable) -> None:
"""Checks if passed in object can be used as a context manager."""
try:
with ctx():
pass
except Exception as e:
raise ValueError("Object must be a valid ctx manager") from e
def validate_rng_states_func(get_rng_tracker: Callable) -> None:
"""Checks if passed in param function has everything
required for tensor/model and sequence parallel.
"""
assert callable(get_rng_tracker), "get_rng_tracker is not a valid function"
rng_tracker = None
try:
rng_tracker = get_rng_tracker()
except Exception as e:
raise RuntimeError("Cannot call get_rng_tracker function") from e
assert hasattr(rng_tracker, "get_states") and callable(
rng_tracker.get_states
), "rng_tracker object does not have valid method get_states"
assert hasattr(rng_tracker, "set_states") and callable(
rng_tracker.set_states
), "rng_tracker object does not have valid method set_states"
assert hasattr(rng_tracker, "fork") and callable(
rng_tracker.fork
), "rng_tracker object does not have valid method fork"
validate_ctx_manager(rng_tracker.fork)
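# A minimal callable that would pass the checks above (a sketch for illustration
# only; a real tracker forks the CUDA RNG state so that tensor-parallel ranks draw
# different random numbers):
#
#     from contextlib import contextmanager
#
#     class _DummyRngStateTracker:
#         def get_states(self):
#             return {}
#         def set_states(self, states):
#             pass
#         @contextmanager
#         def fork(self):
#             yield
#
#     validate_rng_states_func(_DummyRngStateTracker)  # no assertion fires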
def assert_viewless_tensor(
tensor: torch.Tensor, extra_msg: Optional[str] = None
) -> torch.Tensor:
"""Assert that a tensor is not a view (i.e., its '._base' field is
not set)."""
if isinstance(tensor, list):
return [assert_viewless_tensor(t) for t in tensor]
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
f"Ensure tensor._base is None before setting tensor.data or storing "
f"tensor to memory buffer. Otherwise, a memory leak will occur (and "
f"likely accumulate over iterations). {extra_msg}"
)
return tensor
def safely_set_viewless_tensor_data(
tensor: torch.Tensor, new_data_tensor: torch.Tensor
) -> None:
"""Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
"""
extra_msg = (
f"FYI, tensor._base has shape "
f"{'--' if tensor._base is None else tensor._base.shape},"
f"and new_data_tensor has shape {new_data_tensor.shape}."
)
assert_viewless_tensor(tensor, extra_msg=extra_msg)
tensor.data = new_data_tensor
def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""Cast tensor to dtype"""
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
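# Example mirroring its use in TransformerLayer.forward: under autocast, an fp32
# input is downcast to the autocast GPU dtype, otherwise it is returned unchanged
# (a sketch; the dtype below is an assumed choice).
#
#     x = torch.randn(8, 1024, device="cuda")  # float32
#     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#         y = cast_if_needed(x, torch.get_autocast_gpu_dtype())
#     assert y.dtype == torch.bfloat16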