Unverified Commit c6a4a4e0 authored by Kirthi Shankar Sivamani, committed by GitHub

PyTorch refactor (#201)



* Initial refactor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* refactor attention out of transformer.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX export
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f7608d89
@@ -34,7 +34,7 @@ import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, fp8_gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module import get_workspace
from transformer_engine.pytorch.module.base import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
from transformer_engine.pytorch.utils import get_default_init_method
@@ -882,7 +882,7 @@ def test_export_core_attention(
if attn_mask_type is None:
attn_mask_type = 'causal'
inp = (query_layer, key_layer, value_layer)
model = te.transformer.DotProductAttention(
model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
@@ -972,7 +972,7 @@ def test_export_multihead_attention(
input_ln_str = "_input-ln" if input_layernorm else ""
fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"
model = te.transformer.MultiHeadAttention(
model = te.attention.MultiHeadAttention(
*attention_args,
attn_mask_type=attn_mask_type,
params_dtype=precision,
......
@@ -17,8 +17,8 @@ import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module import get_workspace
from transformer_engine.pytorch.module import TransformerEngineBaseModule
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
def init_meta(size: int=1):
......
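The base-module utilities used by these tests now live in the module.base submodule; a minimal sketch of the import change reflected in the hunk above:

# Before this refactor:
#   from transformer_engine.pytorch.module import get_workspace, TransformerEngineBaseModule
# After this refactor:
from transformer_engine.pytorch.module.base import get_workspace, TransformerEngineBaseModule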
@@ -7,7 +7,7 @@ from .module import LayerNormLinear
from .module import Linear
from .module import LayerNormMLP
from .module import LayerNorm
from .transformer import DotProductAttention
from .attention import DotProductAttention
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .export import onnx_export
......
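Downstream code that reached the attention classes through te.transformer needs the new module path. An illustrative sketch (argument values are assumptions for the example; constructing the module requires a CUDA device):

import transformer_engine.pytorch as te

# Old path, before this refactor:
#   attn = te.transformer.DotProductAttention(num_attention_heads=16, kv_channels=64)
# New path (also re-exported at the package top level as te.DotProductAttention):
attn = te.attention.DotProductAttention(num_attention_heads=16, kv_channels=64)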
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Attention."""
import os
import math
from importlib.metadata import version
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union
from pkg_resources import packaging
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.utils import (
divide,
attention_mask_func,
split_tensor_along_dim,
get_device_compute_capability,
)
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
dist_group_type,
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
get_distributed_world_size,
checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("1.0.2")
__all__ = ["DotProductAttention"]
class _SplitLastDim(torch.autograd.Function):
""""""
@staticmethod
def forward(ctx,
mixed_x_layer: torch.Tensor,
num_parts: int
) -> Tuple[torch.Tensor, ...]:
return split_tensor_along_dim(mixed_x_layer, -1, num_parts)
@staticmethod
def backward(ctx,
*grad_outputs):
assert len(grad_outputs) > 0, "No gradients received for backprop!"
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0].storage().data_ptr()
shape = grad_outputs[0].shape
last_dim_size = grad_outputs[0].shape[-1]
for i, tensor in enumerate(grad_outputs):
if (tensor.stride() != strides or
tensor.shape != shape or
tensor.storage().data_ptr() != data_ptr or
tensor.storage_offset() != i * last_dim_size):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(device=grad_outputs[0].device,
dtype=grad_outputs[0].dtype)
new_shape = list(shape)
new_shape[-1] = new_shape[-1] * len(grad_outputs)
ret.set_(grad_outputs[0].storage(),
grad_outputs[0].storage_offset(),
new_shape,
grad_outputs[0].stride()
)
return ret, None
return torch.cat(grad_outputs, dim = -1), None
class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
"""
def __init__(
self,
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attn_mask_type: str = "causal",
layer_number: Optional[int] = None,
) -> None:
super().__init__()
assert (
attn_mask_type in AttnMaskTypes
), f"attn_mask_type {attn_mask_type} not supported"
self.norm_factor = norm_factor
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.scale_mask_softmax = FusedScaleMaskSoftmax(
attn_mask_type,
attention_mask_func,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout)
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""core attention fprop"""
batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
apply_qk_layer_scaling = self.layer_number is not None and key_layer.dtype == torch.float16
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.reshape(
output_size[2], output_size[0] * output_size[1], -1
)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
scale = self.norm_factor
if apply_qk_layer_scaling:
scale *= self.layer_number
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / scale),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# attention scores and attention mask [b, np, sq, sk]
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, softmax_scale)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
attention_probs = self.attention_dropout(attention_probs)
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
output_size = (
value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3),
)
# change view [sk, b * np, hn]
value_layer = value_layer.reshape(
value_layer.size(0), output_size[0] * output_size[1], -1
)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(
output_size[0] * output_size[1], output_size[2], -1
)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
context_layer = context_layer.view(seqlen, batch_size, -1)
return context_layer
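# Reference sketch (not part of this module): the reshapes above implement the
# following einsum computation for q of shape [sq, b, np, hn] and k/v of shape
# [sk, b, np, hn]; masking and qk layer scaling are omitted for brevity.
#
#   scores  = torch.einsum("sbnh,tbnh->bnst", q, k) / norm_factor   # [b, np, sq, sk]
#   probs   = torch.softmax(scores, dim=-1)
#   probs   = torch.nn.functional.dropout(probs, p=attention_dropout)
#   context = torch.einsum("bnst,tbnh->sbnh", probs, v)             # [sq, b, np, hn]
#   context = context.reshape(sq, b, np * hn)                       # [sq, b, hp]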
class _PrepareQKVForFA(torch.autograd.Function):
"""This class converts QKV from interleaved (s, b, ...) layout
to separate contiguous q, k, v tensors in (b, s, ...) layout."""
@staticmethod
def forward(ctx,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor
) -> torch.Tensor:
# All inputs received are non-contiguous tensors.
# The `query_layer` tensor is used to access the
# full memory region of the QKV tensor.
qkv = tex.fa_prepare_fwd(query_layer)
q, k, v = split_tensor_along_dim(qkv, 0, 3)
query_layer = torch.squeeze(q, 0)
key_layer = torch.squeeze(k, 0)
value_layer = torch.squeeze(v, 0)
return query_layer, key_layer, value_layer
@staticmethod
def backward(ctx,
dq: torch.Tensor,
dk: torch.Tensor,
dv: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
dqkv = tex.fa_prepare_bwd(dq, dk, dv)
dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
return dq, dk, dv
def _check_if_interleaved(q, k, v):
data_ptr = q.storage().data_ptr()
check_ptrs = all(x.storage().data_ptr() == data_ptr for x in [q, k, v])
if not check_ptrs:
return False
stride = q.stride()
check_strides = all(stride == x.stride() for x in [q, k, v])
if not check_strides:
return False
shape = q.shape
check_shapes = all(shape == x.shape for x in [q, k, v])
if not check_shapes:
return False
last_dim_size = shape[-1]
check_offsets = all(i * last_dim_size == x.storage_offset()
for i, x in enumerate([q, k, v]))
return check_offsets
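# Illustrative layout (assumption for this note): q, k and v count as interleaved
# when they are views into a single packed QKV buffer of shape [s, b, np, 3 * hn]
# split along the last dimension, i.e. they share a storage pointer, have identical
# shapes and strides, and their storage offsets are 0, hn and 2 * hn respectively --
# exactly the conditions checked above.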
class FlashAttention(torch.nn.Module):
"""Dot product attention implementation by using the flash-attn package.
"""
def __init__(
self,
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attn_mask_type: str = "causal",
) -> None:
super().__init__()
assert (
_flash_attn_version >= _flash_attn_version_required
), f"FlashAttention minimum version {_flash_attn_version_required} is required."
assert (
attn_mask_type == "causal"
), 'FlashAttention currently only supports causal attention mask.'
self.attn_causal_mask = attn_mask_type == "causal"
self.norm_factor = norm_factor
self.attention_dropout_ctx = attention_dropout_ctx
self.attention_dropout = attention_dropout
self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""flash-attn fprop"""
assert (
(query_layer.dtype in [torch.float16, torch.bfloat16])
and (key_layer.dtype in [torch.float16, torch.bfloat16])
and (value_layer.dtype in [torch.float16, torch.bfloat16])
), 'FlashAttention currently only supports FP16 and BF16.'
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), 'FlashAttention currently only supports CUDA tensors.'
assert (
attention_mask is None
), 'FlashAttention currently does not support external attention mask.'
# For now just 128, will make it more general in the future
if (query_layer.shape[-1] == 128 and
query_layer.shape[0] * query_layer.shape[1] >= 512 and
_check_if_interleaved(query_layer, key_layer, value_layer)):
query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
key_layer,
value_layer)
else:
query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
for x in (query_layer, key_layer, value_layer)]
batch_size, seqlen = query_layer.shape[0], query_layer.shape[1]
# [b, sq, np, hn]
query_layer, key_layer, value_layer = [
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
for x in [query_layer, key_layer, value_layer]
]
max_seqlen = seqlen
cu_seqlens = torch.arange(
0,
(batch_size + 1) * seqlen,
step=seqlen,
dtype=torch.int32,
device=query_layer.device)
with self.attention_dropout_ctx():
output = flash_attn_unpadded_func(
query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
self.attention_dropout if self.training else 0.0,
softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask,
deterministic=self.deterministic,
)
# [(b sq), np, hn] -> [sq, b, (np hn)]
return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous()
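# Worked example for the packing above (illustrative values): with batch_size=2 and
# seqlen=3, q/k/v are flattened to shape [6, np, hn] and
# cu_seqlens = torch.arange(0, (2 + 1) * 3, step=3) == tensor([0, 3, 6]),
# i.e. sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensors.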
class DotProductAttention(torch.nn.Module):
"""Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`attn_mask_type` is set to `"causal"`.
.. warning::
For the default attention mechanism, this module executes a non-deterministic version of
`flash-attn <https://github.com/ksivaman/flash-attention>`_ whenever possible in order to
achieve optimal performance. To observe deterministic behavior, set the environment
variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable
`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
Parameters
----------
num_attention_heads : int
number of attention heads in the transformer layer.
kv_channels : int
number of key-value channels.
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules
are concatenated, for instance in consecutive transformer blocks.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_size : int, default = 1
tensor parallel world size.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
"""
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float = 0.0,
attn_mask_type: str = "causal",
sequence_parallel: bool = False,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
tp_group: Optional[dist_group_type] = None,
layer_number: Optional[int] = None,
) -> None:
super().__init__()
tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group
self.get_rng_state_tracker = get_rng_state_tracker
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_partition = divide(projection_size, tp_size)
self.hidden_size_per_attention_head = divide(
projection_size, num_attention_heads
)
if sequence_parallel or get_rng_state_tracker is None:
attention_dropout_ctx = nullcontext
else:
attention_dropout_ctx = get_rng_state_tracker().fork
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.device_compute_capability = get_device_compute_capability()
self.use_flash_attention = (
int(os.getenv("NVTE_FLASH_ATTN", "1"))
and attn_mask_type == "causal"
and self.device_compute_capability >= 8.0
)
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
"attn_mask_type": attn_mask_type,
}
if self.use_flash_attention:
self.flash_attention = FlashAttention(norm_factor, **attn_kwargs)
# Instantiating both types since use of flash-attn
# might be ruled out due to forward inputs.
self.unfused_attention = UnfusedDotProductAttention(
norm_factor, **attn_kwargs, layer_number=layer_number)
def _checkpointed_attention_forward(
self,
attention_func: Callable,
*forward_args: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
"""Forward method with activation checkpointing."""
def custom_forward(*inputs):
return attention_func(*inputs)
hidden_states = checkpoint(
custom_forward,
False,
self.get_rng_state_tracker,
self.tp_group,
*forward_args,
)
return hidden_states
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
checkpoint_core_attention: bool = False,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
is set to `"causal"`.
.. note::
Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer`
must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
:attr:`num_attention_heads`, :attr:`kv_channels`). Output of shape
(:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
* :attr:`kv_channels`) is returned.
Parameters
----------
query_layer : torch.Tensor
Query tensor.
key_layer : torch.Tensor
Key tensor.
value_layer : torch.Tensor
Value tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out softmax input when not using flash-attn.
checkpoint_core_attention : bool, default = `False`
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
"""
use_flash_attention = self.use_flash_attention
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
or (self.device_compute_capability == 8.6 and key_layer.shape[-1] > 64)
):
use_flash_attention = False
if is_in_onnx_export_mode():
use_flash_attention = False
if use_flash_attention:
if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.flash_attention,
query_layer,
key_layer,
value_layer)
return self.flash_attention(query_layer, key_layer, value_layer)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
query_layer,
key_layer,
value_layer,
attention_mask,
)
return self.unfused_attention(query_layer, key_layer, value_layer, attention_mask)
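# Usage sketch (illustrative only, not executed by this module): q/k/v must be
# [sq, b, np, hn]; the returned context is [sq, b, np * hn]. Shapes and sizes
# below are assumptions for the example.
#
#   attn = DotProductAttention(num_attention_heads=16, kv_channels=64)
#   q = torch.randn(128, 4, 16, 64, dtype=torch.bfloat16, device="cuda")
#   k, v = torch.randn_like(q), torch.randn_like(q)
#   out = attn(q, k, v)          # flash path when eligible, otherwise unfused
#   assert out.shape == (128, 4, 16 * 64)
#
# Set NVTE_FLASH_ATTN=0 to force the unfused path, or
# NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for deterministic flash-attn behavior.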
class MultiHeadAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
layernorm_epsilon: float,
init_method: Callable,
output_layer_init_method: Callable,
layer_number: Optional[int] = None,
attn_mask_type: str = "causal",
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
fuse_wgrad_accumulation: bool = False,
get_rng_state_tracker: Optional[Callable] = None,
sequence_parallel: bool = False,
params_dtype: torch.dtype = torch.float32,
return_layernorm_output: bool = False,
input_layernorm: bool = False,
attention_type: str = "self",
set_parallel_mode: bool = False,
fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
bias: bool = True,
) -> None:
super().__init__()
self.layer_number = layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
self.get_rng_state_tracker = get_rng_state_tracker
self.tp_group = tp_group
self.return_layernorm_output = return_layernorm_output
self.params_dtype = params_dtype
self.init_method = init_method
self.attn_mask_type = attn_mask_type
if not fuse_qkv_params:
qkv_weight_interleaved = False
self.qkv_weight_interleaved = qkv_weight_interleaved
assert (
attention_type in AttnTypes
), f"attention_type {attention_type} not supported"
tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_size = tp_size
self.sequence_parallel = (tp_size > 1) and sequence_parallel
self.hidden_size_per_attention_head = kv_channels
self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"tp_group": tp_group,
"tp_size": tp_size,
"get_rng_state_tracker": get_rng_state_tracker,
"sequence_parallel": sequence_parallel,
"params_dtype": params_dtype,
}
qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self":
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
hidden_size,
3 * hidden_size,
eps=layernorm_epsilon,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
**common_gemm_kwargs,
)
else:
self.qkv = Linear(
hidden_size,
3 * hidden_size,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)
else:
if self.input_layernorm:
self.layernorm_query = LayerNormLinear(
hidden_size,
hidden_size,
eps=layernorm_epsilon,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
**common_gemm_kwargs,
)
else:
self.query_layer = Linear(
hidden_size,
hidden_size,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
**common_gemm_kwargs,
)
self.key_value = Linear(
hidden_size,
2 * hidden_size,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("key_", "value_") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)
# Attention.
self.core_attention = DotProductAttention(
num_attention_heads,
kv_channels,
attention_dropout,
tp_size=tp_size,
get_rng_state_tracker=get_rng_state_tracker,
attn_mask_type=attn_mask_type,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
layer_number=layer_number,
)
# Linear
self.proj = Linear(
hidden_size,
hidden_size,
init_method=output_layer_init_method,
bias=bias,
return_bias=True,
parallel_mode="row" if set_parallel_mode else None,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
**common_gemm_kwargs,
)
def _allocate_memory(
self, inference_max_sequence_len: int, batch_size: int
) -> torch.Tensor:
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device(),
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group"""
self.tp_group = tp_group
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_output: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""MultiHeadAttention FWD"""
# hidden_states: [sq, b, h]
if self.attn_mask_type != "causal" and attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params and self.layer_number is not None:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
inference_value_memory,
)
else:
(
inference_key_memory,
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == "self":
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs
else:
mixed_x_layer = layernorm_qkv_outputs
else:
mixed_x_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.qkv_weight_interleaved:
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
# split along last dimension
split_dim = -1
else:
# [sq, b, (np * 3 * hn)] --> [sq, b, 3 * np, hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
3 * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
# split along second last dimension
split_dim = -2
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# mixed_x_layer --> 3 [sq, b, np, hn]
if split_dim == -1 and not is_in_onnx_export_mode():
query_layer, key_layer, value_layer = _SplitLastDim.apply(mixed_x_layer, 3)
else:
query_layer, key_layer, value_layer = split_tensor_along_dim(
mixed_x_layer, split_dim, 3
)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
)
if self.qkv_weight_interleaved:
# [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
# split along last dimension
split_dim = -1
else:
# [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
2 * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
# split along second last dimension
split_dim = -2
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# mixed_kv_layer --> 2 [sk, b, np, hn]
if split_dim == -1 and not is_in_onnx_export_mode():
key_layer, value_layer = _SplitLastDim.apply(mixed_kv_layer, 2)
else:
key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
else:
query_layer = layernorm_query_outputs
else:
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if inference_params and self.layer_number is not None:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...
]
# ==================================
# core attention computation
# ==================================
context_layer = self.core_attention(
query_layer,
key_layer,
value_layer,
attention_mask,
checkpoint_core_attention=checkpoint_core_attention,
)
# =================
# Output. [sq, b, h]
# =================
attention_output, attention_bias = self.proj(
context_layer, is_first_microbatch=is_first_microbatch
)
if self.input_layernorm and self.return_layernorm_output:
return attention_output, attention_bias, layernorm_output
return attention_output, attention_bias
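A minimal construction sketch for the relocated MultiHeadAttention (argument values are illustrative assumptions; the class is reached through te.attention as in the updated ONNX-export tests above, and a CUDA device is required):

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.utils import get_default_init_method

mha = te.attention.MultiHeadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.1,
    layernorm_epsilon=1e-5,
    init_method=get_default_init_method(),
    output_layer_init_method=get_default_init_method(),
    params_dtype=torch.float16,
)
hidden_states = torch.randn(128, 4, 1024, dtype=torch.float16, device="cuda")
attn_out, attn_bias = mha(hidden_states)  # projection output and its (unapplied) bias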
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Module level PyTorch APIs"""
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Base modules and utilities for TransformerEngine PyTorch API"""
import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Union, Optional, Tuple, Dict, Any, List
from functools import partial
from contextlib import contextmanager
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import transformer_engine_extensions as tex
from ..fp8 import (
is_fp8_enabled,
is_fp8_calibration,
get_fp8_recipe,
get_fp8_group,
get_default_fp8_recipe,
get_fp8_te_dtype,
is_first_fp8_module,
new_fp8_context_id,
get_fp8_context_id,
set_fp8_context_id,
add_amax_to_global_buffer,
copy_amax_from_global_buffer,
global_amax_reduction,
setup_amax_forward_global_reduce_func,
amax_and_scale_update,
get_global_fp8_buffer,
set_global_fp8_buffer,
set_amax_buffer_key_deletion,
delete_key_from_amax_buffer,
copy_forward_fp8_meta_tensors_for_recompute,
get_old_fp8_meta_tensors_for_recompute,
restore_fp8_meta_tensors,
get_amax_reduce_handle_fwd,
)
from ..distributed import (
gather_along_first_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
)
from ..cpp_extensions import (
fp8_cast_transpose_fused,
fp8_cast_transpose_bgrad_fused,
cast_to_fp8,
)
from ..constants import dist_group_type
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None
def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB if using Hopper, 4 MiB for all other architectures."""
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
return 33_554_432
return 4_194_304
def get_workspace() -> torch.Tensor:
"""Returns workspace for cublas."""
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
)
return _cublas_workspace
@contextmanager
def _prepare_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
name: str = ""
) -> None:
"""Checks and prep for BWD."""
if fp8:
global _amax_reduce_handle_bwd
if _amax_reduce_handle_bwd is not None:
_amax_reduce_handle_bwd.wait()
_amax_reduce_handle_bwd = None
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
else:
# From previous iteration
copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
add_amax_to_global_buffer(fp8_meta, forward=False)
with torch.cuda.nvtx.range(name + " backward"):
yield
if fp8 and fp8_meta["recipe"].reduce_amax:
if fp8_meta["first_module"]:
_amax_reduce_handle_bwd = global_amax_reduction(
fp8_meta,
tp_group,
tp_size,
forward=False
)
delete_key_from_amax_buffer(forward=False)
def initialize_ub(
shape: list,
tp_size: int,
use_fp8: bool = False,
ub_cfgs: Optional[dict] = None
) -> None:
"""Initialize communicators for TP comm overlap using userbuffers."""
global _ub_communicators
assert _ub_communicators is None, "UB communicators are already initialized."
_ub_communicators = {}
rank_id = torch.distributed.get_rank()
# Increase the workspace by the number of maximum concurrent streams
global _cublas_workspace
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
# Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
fp8_buf = [
"qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
]
# Default overlap methods for layers
methods = {
"ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
"pipeline":["proj_fprop", "fc2_fprop"],
"bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
def get_method(name):
for method, names in methods.items():
if name in names:
return method
raise KeyError(f"Given layer name {name} does not exist.")
def add_ub(
name: str,
method: str,
num_sm: int = 16,
cga_size: int = 2,
set_sm_margin: int = 0,
num_splits: int = 4,
aggregate: int = 0,
) -> None:
dtype = torch.uint8 if (use_fp8 and name in fp8_buf) else torch.bfloat16
sample_buffer = torch.empty(shape, dtype=dtype, device='cuda')
if method == 'ring_exchange':
ub_obj = tex.UbufP2PCommOverlap(
sample_buffer, # Sample userbuffer
rank_id, # Rank id
tp_size, # TP size
aggregate, # Aggregate 2X GEMM chunks
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
)
else:
ub_obj = tex.UbufCommOverlap(
sample_buffer, # Sample userbuffer
rank_id, # Rank id
tp_size, # TP size
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
num_splits, # Number of communication splits
set_sm_margin, # Set SM margin
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
)
_ub_communicators[name] = ub_obj
for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]):
if ub_cfgs is not None and name in ub_cfgs:
ub_cfg = ub_cfgs[name]
method = ub_cfg["method"] if "method" in ub_cfg else get_method(name)
num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16
cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0
set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0
aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0
add_ub(
name,
method,
num_sm,
cga_size,
set_sm_margin,
num_splits,
aggregate
)
else:
method = get_method(name)
if method == "pipeline":
add_ub(name, method)
else:
add_ub(name, method, num_splits=0)
def get_ub(name: str):
"""Get userbuffer communicator corresponding to give key."""
global _ub_communicators
assert _ub_communicators is not None, "UB manager is not initialized."
assert name in _ub_communicators, f"UB for {name} is not registered."
return _ub_communicators[name]
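# Usage sketch for the userbuffer setup above (values are illustrative assumptions;
# requires an initialized torch.distributed process group):
#
#   initialize_ub(
#       shape=[seq_len * batch_size, hidden_size],
#       tp_size=tp_size,
#       use_fp8=True,
#       ub_cfgs={"qkv_fprop": {"num_sm": 8, "cga_size": 2}},  # optional per-layer overrides
#   )
#   ub_obj = get_ub("qkv_fprop")  # communicator registered for the QKV fprop GEMM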
class _NoopCat(torch.autograd.Function):
"""This class is a no-op replacement for `torch.cat`."""
@staticmethod
def forward(ctx,
full_param_buffer: torch.Tensor,
*params_split: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
assert (
full_param_buffer.shape[0] % len(params_split) == 0
), "Dimensions not compatible for concatenation"
param_temp = full_param_buffer.new()
param_temp.set_(full_param_buffer.storage(),
full_param_buffer.storage_offset(),
full_param_buffer.size(),
full_param_buffer.stride())
param_temp.requires_grad = True
ctx.save_for_backward(full_param_buffer, *params_split)
return param_temp
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
full_param_buffer, *params_split = ctx.saved_tensors
split_size = full_param_buffer.shape[0] // len(params_split)
grads = []
for i, _ in enumerate(params_split):
grads.append(grad_output[i * split_size : (i+1) * split_size])
return None, *grads
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
def __init__(self) -> None:
super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.fp8_initialized = False
self.fp8 = False
self.fp8_calibration = False
self.fp8_meta = {}
self.fp8_meta["fp8_group"] = None
self.fp8_meta["recipe"] = get_default_fp8_recipe()
self.fp8_meta_tensors_initialized = False
self.tp_group = None
self.tp_size = 1
self.sequence_parallel = False
self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
)
def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
if self.fp8_meta_tensors_initialized:
# Handle changed amax history size.
curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0]
need_len = self.fp8_meta["recipe"].amax_history_len
if need_len < curr_len:
self.fp8_meta[fp8_meta_tensor_key].amax_history = (
self.fp8_meta[fp8_meta_tensor_key]
.amax_history[: self.fp8_meta["recipe"].amax_history_len].clone()
)
elif need_len > curr_len:
extra_rows = need_len - curr_len
self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad(
self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows)
)
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
num_fp8_tensors = (
self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2
)
self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones(
num_fp8_tensors, dtype=torch.float32, device="cuda"
)
self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones(
num_fp8_tensors, dtype=torch.float32, device="cuda"
)
self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros(
self.fp8_meta["recipe"].amax_history_len,
num_fp8_tensors,
dtype=torch.float32,
device="cuda",
)
# Needed for calculation of scale inverses to
# preserve scale_inv when caching FP8 weights
if fwd:
# [True, False, True]: -> [input, weight, output]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, False, True] * self.fp8_meta["num_gemms"]
).cuda()
else:
# [True, True]: -> [grad_output, grad_input]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, True] * self.fp8_meta["num_gemms"]
).cuda()
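# Example of the bookkeeping above (illustrative): with num_gemms = 2, the forward
# meta holds 3 * 2 = 6 tensors ordered [input, weight, output] per GEMM, so
# scaling_fwd_non_weight_mask == [True, False, True, True, False, True]; the
# backward meta holds 2 * 2 = 4 tensors ([grad_output, grad_input] per GEMM).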
def init_fp8_meta_tensors(self) -> None:
"""Init scales and amaxes."""
self.set_meta_tensor(True)
self.set_meta_tensor(False)
self.fp8_meta_tensors_initialized = True
def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing."""
state = None
if self.fp8 or self.fp8_calibration:
state = {}
state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
state["global_fp8_buffer"] = get_global_fp8_buffer()
# Store other picklable values.
extra = {}
for k, v in self.fp8_meta.items():
if isinstance(v, (bool, int, float, str)):
extra[k] = v
state["extra_fp8_variables"] = extra
state_serialized = pickle.dumps(state)
state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8))
return state_tensor
def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state."""
if state is None:
return
# Maintain backward compatibility with v0.2.0 and older.
if isinstance(state, list):
warnings.warn(
"This checkpoint format is deprecated and will be"
"removed in a future release of Transformer Engine"
)
# Retrieve checkpointed items.
scale_fwd = state[0]
amax_history_fwd = state[1]
scale_bwd = state[2]
amax_history_bwd = state[3]
self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
self.fp8_meta["num_gemms"] = (
amax_history_fwd.shape[1] // 2
) # Two FWD tensors per GEMM
# Initialize before loading
self.init_fp8_meta_tensors()
self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
# Restore global FP8 buffer state.
set_global_fp8_buffer(state[4])
self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
self.fp8_meta["autocast_id_fwd"] = state[8]
self.fp8_meta["autocast_id_bwd"] = state[9]
return
if isinstance(state, torch.Tensor):
state = pickle.loads(state.detach().cpu().numpy().tobytes())
if state is None:
return
# Restore global FP8 buffer states.
set_global_fp8_buffer(state["global_fp8_buffer"])
# Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
# Initialize before loading.
self.init_fp8_meta_tensors()
self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
# Backwards compatibility: compute scale inv if it wasn't saved in the extra state.
if "scale_inv_fwd" not in state or "scale_inv_bwd" not in state:
assert (
"scale_inv_fwd" not in state and "scale_inv_bwd" not in state
), "Invalid state, began saving scale_inv_fwd and scale_inv_bwd at the same time"
self.fp8_meta["scaling_fwd"].scale_inv.copy_(1.0/state["scale_fwd"])
self.fp8_meta["scaling_bwd"].scale_inv.copy_(1.0/state["scale_bwd"])
else:
self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
def set_activation_dtype(self, inp: torch.Tensor) -> None:
"""Get activation data type for AMP."""
# Native AMP (`torch.autocast`) gets highest priority
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
return
# All checks after this have already been performed once, thus skip
# We assume that user doesn't change input types across iterations
if hasattr(self, "activation_dtype"):
return
assert all(
(
(inp.dtype == param.dtype) if param is not None else True
for param in self.parameters()
)
), (
"Data type for activations and weights must "
"match when outside of autocasted region"
)
assert all(
(
(inp.dtype == buf.dtype) if buf is not None else True
for buf in self.buffers()
)
), (
"Data type for activations and buffers must "
"match when outside of autocasted region"
)
self.activation_dtype = inp.dtype
def set_fp8_weights(self) -> None:
"""Initializes FP8 weights for the module as class attributes. These
are not parameters or buffers since we do not want functions such as
`.to(dtype)` or `.to(device)` to affect them. These also do not need
to be checkpointed. During `init` phase of the module, the attribute
`fp8_weight_shapes` must be populated with the tensor shapes for FP8
weights. This function will iterate over those shapes and initialize
respective attributes named `weight1_fp8`, `weight2_fp8`, ...
"""
if not self.fp8:
return
for i, shape in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_attr = f"weight{i}_fp8"
weight_transpose_attr = f"weight{i}_t_fp8"
if (
hasattr(self, weight_cast_attr)
and getattr(self, weight_cast_attr).shape == shape
):
return
setattr(
self,
weight_cast_attr,
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
)
setattr(
self,
weight_transpose_attr,
torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group."""
self.tp_group = tp_group
self.tp_group_initialized = True
# This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
self.fp8 = is_fp8_enabled()
self.fp8_calibration = is_fp8_calibration()
if self.fp8 or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything.
if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
return
# Set FP8, recipe, and other FP8 metadata
self.fp8_meta["recipe"] = get_fp8_recipe()
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = get_fp8_group()
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes
self.init_fp8_meta_tensors()
self.fp8_initialized = True
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
return
@contextmanager
def prepare_forward(
self,
inp: torch.Tensor,
is_first_microbatch: Union[bool, None],
num_gemms: int = 1,
) -> None:
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to set up the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else:
assert inp.is_cuda, "TransformerEngine needs CUDA."
if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized."
self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if self.fp8 and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \
"necessary when using sequence parallelism with FP8."
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
if amax_reduce_handle_fwd is not None:
amax_reduce_handle_fwd.wait()
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
self.fp8_meta["autocast_id_fwd_stack"].append(
self.fp8_meta["autocast_id_fwd"]
)
add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
# Activation recomputation is used and this is the first forward phase.
if (
self.fp8
and self.training
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
yield inp.contiguous()
if self.fp8 and in_fp8_activation_recompute_phase():
restore_fp8_meta_tensors(self.fp8_meta)
return
if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
global_amax_reduction,
self.fp8_meta,
self.tp_group,
self.tp_size,
forward=True
)
setup_amax_forward_global_reduce_func(reduce_func)
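# Typical use by a subclass forward pass (sketch; the exact call sites live in the
# Linear/LayerNormLinear modules and the compute step below is hypothetical):
#
#   with self.prepare_forward(inp, is_first_microbatch, num_gemms=1) as inp:
#       out = _SomeAutogradFn.apply(inp, ...)
#   return out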
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the
host side in TE, the comm calls are always launched first, but
to ensure that the GEMM isn't scheduled first, the environment
variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
force a single channel.
"""
if self.tp_size == 1:
return
num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
if num_cuda_work_queues != 1:
warnings.warn(
"To guarantee overlapping TP and SP collectives with the backward"
"GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
)
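# To avoid this warning when running with tensor parallelism, export the variable
# before CUDA is initialized, e.g. in the launch environment (launcher command is
# illustrative; any mechanism that sets the variable works):
#
#   CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc-per-node=8 train.py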
@staticmethod
def grad_output_preprocess(
ctx, grad_output: torch.Tensor, row_parallel_mode: bool
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Utility function for backward.
Returns tuple in order (all optional/None based on training precision/recipe):
R1: gathered `grad_output` in higher precision.
R2: gathered `grad_output` in FP8.
R3: R2 transposed.
R4: bias gradient on R1.
"""
grad_output = grad_output.contiguous()
grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# No-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8:
if gather_grad_output:
if not ctx.ub_split_ag:
grad_output_mat, _ = gather_along_first_dim(
grad_output_mat, ctx.tp_group
)
else:
ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True)
grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1)
return grad_output_mat, None, None, None
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# FP8 case with non-FP8 wgrad
if (
gather_grad_output
and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
):
assert (
not ctx.ub_split_ag
), "override_linear_precision.wgrad not supported with ub_split_ag"
grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
# FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
elif gather_grad_output:
if ctx.use_bias:
grad_bias = grad_output_mat.sum(dim=0)
else:
grad_bias = None
if ctx.ub_split_ag:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
else:
grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
out=grad_output_c,
)
if not ctx.ub_split_ag:
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
else:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
grad_output_t = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
# FP8 case without gather: cast, transpose, bgrad fused
if ctx.use_bias:
grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
else:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
grad_output_c, grad_output_t = fp8_cast_transpose_fused(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
else:
grad_output_t = None
grad_output_c = cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
grad_bias = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor:
"""No-op replacement of `torch.cat`. The buffer and split parameters must occupy
the same memory region. If this is not the case, then the split parameters
are concatenated and the buffer is overwritten. The parameters' memory is then
re-assigned to point to the buffer to avoid subsequent concatenations.
"""
assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
full_param_buffer = getattr(self, buffer_name)
split_size = full_param_buffer.shape[0] // len(pnames)
params = [getattr(self, name) for name in pnames]
for i, p in enumerate(params):
if p.data.data_ptr() != full_param_buffer[i*split_size : (i+1)*split_size].data_ptr():
with torch.no_grad():
setattr(self, buffer_name, torch.cat(params))
for j, pname in enumerate(pnames):
full_param_buffer = getattr(self, buffer_name)
setattr(self, pname,
Parameter(full_param_buffer[j*split_size : (j+1)*split_size]))
break
return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])
@abstractmethod
def forward(self):
"""Needs override."""
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LayerNorm API"""
import os
from typing import Union, Tuple, Any, Mapping
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_extensions as tex
__all__ = ["LayerNorm"]
class _LayerNorm(torch.autograd.Function):
"""functional LayerNorm"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
eps: float,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features))
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
ln_bias, eps, fwd_ln_sm_margin,
zero_centered_gamma)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
return ln_out.view_as(inp)
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None
class LayerNorm(torch.nn.Module):
r"""
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size :attr:`hidden_size`
Parameters
----------
hidden_size : int
size of each input sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
params_dtype: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
) -> None:
super().__init__()
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.weight = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.bias = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.weight, "sequence_parallel", sequence_parallel)
setattr(self.bias, "sequence_parallel", sequence_parallel)
self.reset_layer_norm_parameters()
# This many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def load_state_dict(
self,
state_dict: Mapping[str, Any],
strict: bool = True,
) -> None:
"""Override PyTorch loader to maintain backward compatibility
with previous version of LayerNorm parameter names.
"""
if "layer_norm_weight" in state_dict:
state_dict["weight"] = state_dict["layer_norm_weight"]
del state_dict["layer_norm_weight"]
if "layer_norm_bias" in state_dict:
state_dict["bias"] = state_dict["layer_norm_bias"]
del state_dict["layer_norm_bias"]
super().load_state_dict(state_dict, strict)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
if not self.zero_centered_gamma:
init.ones_(self.weight)
else:
init.zeros_(self.weight)
init.zeros_(self.bias)
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# Maintain backward compatibility.
if hasattr(self, "layer_norm_weight"):
setattr(self, "weight", self.layer_norm_weight)
if hasattr(self, "layer_norm_bias"):
setattr(self, "bias", self.layer_norm_bias)
return _LayerNorm.apply(
inp,
self.weight,
self.bias,
self.eps,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma
)
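# --- Illustrative usage (not part of the library source) ---
# A minimal sketch of the LayerNorm module defined above, assuming a CUDA
# device and the default fp32 parameters; the tensor sizes are arbitrary.
import torch
import transformer_engine.pytorch as te

ln = te.LayerNorm(hidden_size=1024, eps=1e-5)    # parameters live on the current CUDA device
x = torch.randn(8, 128, 1024, device="cuda")     # input must be a CUDA tensor
y = ln(x)                                        # output has the same shape as `x`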
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LayerNormLinear API"""
import os
from typing import Union, Optional, Callable, Tuple, Dict, Any
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_extensions as tex
from .base import (
get_workspace,
_prepare_backward,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import (
divide,
get_default_init_method,
cast_if_needed,
check_dim_for_fp8_forward_exec,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
allreduce,
initialize_affine_weight_gpu,
reduce_scatter_along_first_dim,
gather_along_first_dim,
)
from ..cpp_extensions import (
fp8_gemm,
gemm,
fp8_cast_transpose_fused,
layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
cast_to_fp8,
cast_from_fp8,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
__all__ = ["LayerNormLinear"]
class _LayerNormLinear(torch.autograd.Function):
"""LayerNormLinear semi-top level module
Calls custom cuda extensions.
"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: Union[torch.Tensor, None],
bias: torch.Tensor,
use_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
return_layernorm_output: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_ag: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
), "Input and weight dimensions are not compatible for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)
# If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if ub_split_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
if ub_split_ag:
dim_size = list(inputmat.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0)
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
if is_grad_enabled:
if not ub_split_ag:
ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
_, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
ln_out=ln_out
)
else:
mu = rsigma = None
ln_out = layernorm_fwd_fp8_inf(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
zero_centered_gamma,
)
else:
if is_grad_enabled:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out_return, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out = cast_to_fp8(
ln_out_return,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
if is_grad_enabled:
if ub_split_ag:
_, mu, rsigma = tex.layernorm_fwd_noalloc(
inputmat, ln_weight, ln_bias, ln_out, eps,
fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out_return = ln_out
# Column Parallel Linear
if ub_split_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif parallel_mode == "column" and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
if fp8:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
else activation_dtype
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if is_grad_enabled:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
out = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
ln_out_total,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
)
else:
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
if fp8_calibration:
# amax of input
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(ln_out_total).float()
# amax of weight
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
out, _, _ = gemm(
weight,
ln_out_total,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
)
if is_grad_enabled:
ctx.save_for_backward(
inputmat,
ln_weight,
mu,
rsigma,
weight,
weight_t_fp8,
ln_out,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.return_layernorm_output = return_layernorm_output
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.requires_dgrad = inp.requires_grad
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
if return_layernorm_output:
return out, ln_out_return.view_as(inp)
return out
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
):
(
inputmat,
ln_weight,
mu,
rsigma,
weight,
weight_t_fp8,
ln_out,
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_bulk_dgrad = False
if ctx.ub_bulk_dgrad:
dim_size = list(ln_out.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_dgrad")
ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], ctx.parallel_mode == "row"
)
if ctx.ub_bulk_wgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_bulk_wgrad = False
# Column Parallel Linear
# Overlap input AG with dgrad
if (not ctx.ub_bulk_dgrad) and ctx.parallel_mode == "column" and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
)
else:
ln_out_total = ln_out
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
dgrad_size = list(grad_output.size())
dgrad_size[1] = weight.size(1)
if ctx.ub_bulk_wgrad: # allocate dgrad output
ub_obj_dgrad = get_ub("qkv_wgrad")
dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
else:
dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
out=dgrad,
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
out=dgrad,
layout="NN",
grad=True,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
)
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if not ctx.ub_bulk_dgrad:
handle.wait()
if not ctx.ub_bulk_wgrad:
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
else:
ln_out_total_c = cast_from_fp8(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
wgrad, _, _ = gemm(
ln_out_total_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
ln_out_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
if ctx.ub_bulk_wgrad:
dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
# Column Parallel Linear
elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
# LayerNorm gradient
d_ln_out = dgrad.view(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
if not ctx.use_bias:
grad_bias = None
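# autograd requires one gradient entry per forward() argument; the trailing
# Nones cover the non-tensor/configuration arguments, which have no gradient.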
return (
dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
dbeta,
wgrad if weight.requires_grad else None,
None,
None,
grad_bias,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class LayerNormLinear(TransformerEngineBaseModule):
r"""
Applies layer normalization followed by linear transformation to the incoming data.
Parameters
----------
in_features : int
size of each input sample.
out_features : int
size of each output sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
parameters_split : Tuple[str, ...], default = None
if a tuple of strings is provided, the weight and bias parameters of the
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
skip_weight_param_allocation: bool, default = `False`
if set to `True`, weight parameter is not allocated and must be
passed as a keyword argument `weight` during the forward pass.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(
self,
in_features: int,
out_features: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
fuse_wgrad_accumulation: bool = False,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
params_dtype: torch.dtype = torch.float32,
parallel_mode: Optional[str] = None,
return_layernorm_output: bool = False,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
zero_centered_gamma: bool = False,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
if init_method is None:
init_method = get_default_init_method()
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
self.eps = eps
self.layer_norm_weight = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.layer_norm_bias = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
self.reset_layer_norm_parameters()
if not skip_weight_param_allocation:
self.register_buffer("weight_tensor",
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
initialize_affine_weight_gpu(
self.weight_tensor,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias:
self.register_buffer("bias_tensor",
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
else:
self.register_buffer("bias_tensor",
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
with torch.no_grad():
self.bias_tensor.zero_()
if parameters_split is None:
parameters_split = ("",)
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
self.weight_names = []
self.bias_names = []
for i, pname in enumerate(parameters_split):
wname = pname + "weight"
bname = pname + "bias"
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
self.register_buffer(bname,
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
self.weight_names.append(wname)
self.bias_names.append(bname)
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
else:
self.gemm_bias_unfused_add = False
# This many SMs are subtracted from the total SM count when calling the forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs on the device, which is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
if not self.zero_centered_gamma:
init.ones_(self.layer_norm_weight)
else:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)
def forward(
self,
inp: torch.Tensor,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
Parameters
----------
inp : torch.Tensor
Input tensor.
weight : torch.Tensor, default = None
An optional weight tensor for the module. This argument is compulsory if the
module is initialized with `skip_weight_param_allocation=True`.
bias : torch.Tensor, default = None
An optional bias tensor for the module. This argument is compulsory if the
module is initialized with `skip_weight_param_allocation=True` and one of
`use_bias` or `return_bias` is set to `True`.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = (
bias if bias is not None
else self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names)
)
weight_tensor = (
weight if weight is not None
else self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names)
)
if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply
args = []
else:
fwd_fn = _LayerNormLinear.forward
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
torch.is_grad_enabled(),
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_ag,
)
out = fwd_fn(*args)
if self.return_layernorm_output:
out, ln_out = out
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
if self.return_bias:
if self.return_layernorm_output:
return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out
return out, cast_if_needed(bias_tensor, self.activation_dtype)
if self.return_layernorm_output:
return out, ln_out
return out
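# --- Illustrative usage (not part of the library source) ---
# A minimal sketch of driving the LayerNormLinear module above under gradient
# accumulation. The shapes, the microbatch count, and the use of
# `return_layernorm_output` are assumptions made for the example.
import torch
import transformer_engine.pytorch as te

layer = te.LayerNormLinear(in_features=1024, out_features=3072,
                           return_layernorm_output=True)
microbatches = [torch.randn(128, 1024, device="cuda") for _ in range(4)]
for i, mb in enumerate(microbatches):
    # `is_first_microbatch=True` only on the first microbatch lets the module
    # cache FP8 weight casts for the rest of the minibatch when FP8 is enabled.
    out, ln_out = layer(mb, is_first_microbatch=(i == 0))
    out.sum().backward()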
......@@ -2,2327 +2,61 @@
#
# See LICENSE for license information.
"""Top level Transformer Engine PyTorch modules"""
"""LayerNormMLP API"""
import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping, List
from functools import partial
from contextlib import contextmanager
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_extensions as tex
from .fp8 import (
is_fp8_enabled,
is_fp8_calibration,
get_fp8_recipe,
get_fp8_group,
get_default_fp8_recipe,
get_fp8_te_dtype,
is_first_fp8_module,
new_fp8_context_id,
get_fp8_context_id,
set_fp8_context_id,
add_amax_to_global_buffer,
copy_amax_from_global_buffer,
global_amax_reduction,
setup_amax_forward_global_reduce_func,
amax_and_scale_update,
get_global_fp8_buffer,
set_global_fp8_buffer,
set_amax_buffer_key_deletion,
delete_key_from_amax_buffer,
copy_forward_fp8_meta_tensors_for_recompute,
get_old_fp8_meta_tensors_for_recompute,
restore_fp8_meta_tensors,
get_amax_reduce_handle_fwd,
)
from .jit import (
bias_gelu_fused,
bgrad_dgelu_fused,
set_jit_fusion_options,
warmup_jit_bias_gelu_all_dtypes,
)
from .utils import (
divide,
get_default_init_method,
cast_if_needed,
check_dim_for_fp8_forward_exec,
)
from .distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
allreduce,
initialize_affine_weight_gpu,
reduce_scatter_along_first_dim,
gather_along_first_dim,
gather_along_last_dim,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
)
from .cpp_extensions import (
fp8_gemm,
gemm,
fp8_cast_transpose_fused,
fp8_cast_transpose_bgrad_fused,
fp8_gelu,
fp8_cast_transpose_bgrad_dgelu_fused,
layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
cast_to_fp8,
cast_from_fp8,
)
from .constants import GemmParallelModes, dist_group_type, TE_DType
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None
def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB if using Hopper, 4 MiB for all other architectures."""
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
return 33_554_432
return 4_194_304
def get_workspace() -> torch.Tensor:
"""Returns workspace for cublas."""
global _cublas_workspace
if _cublas_workspace is None:
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
)
return _cublas_workspace
@contextmanager
def _prepare_backward(
fp8: bool,
fp8_meta: Dict[str, Any],
tp_group: dist_group_type,
tp_size: int,
name: str = ""
) -> None:
"""Checks and prep for BWD."""
if fp8:
global _amax_reduce_handle_bwd
if _amax_reduce_handle_bwd is not None:
_amax_reduce_handle_bwd.wait()
_amax_reduce_handle_bwd = None
# Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False)
else:
# From previous iteration
copy_amax_from_global_buffer(fp8_meta, forward=False)
amax_and_scale_update(fp8_meta, False)
set_amax_buffer_key_deletion(fp8_meta, forward=False)
# Get new backward key.
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
add_amax_to_global_buffer(fp8_meta, forward=False)
with torch.cuda.nvtx.range(name + " backward"):
yield
if fp8 and fp8_meta["recipe"].reduce_amax:
if fp8_meta["first_module"]:
_amax_reduce_handle_bwd = global_amax_reduction(
fp8_meta,
tp_group,
tp_size,
forward=False
)
delete_key_from_amax_buffer(forward=False)
def initialize_ub(
shape: list,
tp_size: int,
use_fp8: bool = False,
ub_cfgs: Optional[dict] = None
) -> None:
"""Initialize communicators for TP comm overlap using userbuffers."""
global _ub_communicators
assert _ub_communicators is None, "UB communicators are already initialized."
_ub_communicators = {}
rank_id = torch.distributed.get_rank()
# Increase the workspace by the number of maximum concurrent streams
global _cublas_workspace
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
# Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
fp8_buf = [
"qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
]
# Default overlap methods for layers
methods = {
"ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
"pipeline":["proj_fprop", "fc2_fprop"],
"bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
def get_method(name):
for method, names in methods.items():
if name in names:
return method
raise KeyError(f"Given layer name {name} does not exist.")
def add_ub(
name: str,
method: str,
num_sm: int = 16,
cga_size: int = 2,
set_sm_margin: int = 0,
num_splits: int = 4,
aggregate: int = 0,
) -> None:
dtype = torch.uint8 if (use_fp8 and name in fp8_buf) else torch.bfloat16
sample_buffer = torch.empty(shape, dtype=dtype, device='cuda')
if method == 'ring_exchange':
ub_obj = tex.UbufP2PCommOverlap(
sample_buffer, # Sample userbuffer
rank_id, # Rank id
tp_size, # TP size
aggregate, # Aggregate 2X GEMM chunks
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
)
else:
ub_obj = tex.UbufCommOverlap(
sample_buffer, # Sample userbuffer
rank_id, # Rank id
tp_size, # TP size
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
num_splits, # Number of communication splits
set_sm_margin, # Set SM margin
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
)
_ub_communicators[name] = ub_obj
for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]):
if ub_cfgs is not None and name in ub_cfgs:
ub_cfg = ub_cfgs[name]
method = ub_cfg["method"] if "method" in ub_cfg else get_method(name)
num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16
cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0
set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0
aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0
add_ub(
name,
method,
num_sm,
cga_size,
set_sm_margin,
num_splits,
aggregate
)
else:
method = get_method(name)
if method == "pipeline":
add_ub(name, method)
else:
add_ub(name, method, num_splits=0)
def get_ub(name: str):
"""Get userbuffer communicator corresponding to give key."""
global _ub_communicators
assert _ub_communicators is not None, "UB manager is not initialized."
assert name in _ub_communicators, f"UB for {name} is not registered."
return _ub_communicators[name]
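# --- Illustrative configuration sketch (not part of the library source) ---
# Example of the `ub_cfgs` override consumed by `initialize_ub` above. Only the
# listed layers are overridden; the remaining layers fall back to the defaults
# in `methods`. The tuning values below are assumptions, and `seq_len`,
# `micro_batch_size`, and `hidden_size` are placeholders for the caller's own
# dimensions.
_example_ub_cfgs = {
    "proj_fprop": {"method": "pipeline", "num_sm": 8, "num_splits": 4},
    "qkv_dgrad": {"method": "bulk", "num_splits": 0},
}
# initialize_ub(shape=[seq_len * micro_batch_size, hidden_size], tp_size=8,
#               use_fp8=True, ub_cfgs=_example_ub_cfgs)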
class _NoopCat(torch.autograd.Function):
"""This class is a no-op replacement for `torch.cat`."""
@staticmethod
def forward(ctx,
full_param_buffer: torch.Tensor,
*params_split: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
assert (
full_param_buffer.shape[0] % len(params_split) == 0
), "Dimensions not compatible for concatenation"
param_temp = full_param_buffer.new()
param_temp.set_(full_param_buffer.storage(),
full_param_buffer.storage_offset(),
full_param_buffer.size(),
full_param_buffer.stride())
param_temp.requires_grad = True
ctx.save_for_backward(full_param_buffer, *params_split)
return param_temp
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
full_param_buffer, *params_split = ctx.saved_tensors
split_size = full_param_buffer.shape[0] // len(params_split)
grads = []
for i, _ in enumerate(params_split):
grads.append(grad_output[i * split_size : (i+1) * split_size])
return None, *grads
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
def __init__(self) -> None:
super().__init__()
assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
self.fp8_initialized = False
self.fp8 = False
self.fp8_calibration = False
self.fp8_meta = {}
self.fp8_meta["fp8_group"] = None
self.fp8_meta["recipe"] = get_default_fp8_recipe()
self.fp8_meta_tensors_initialized = False
self.tp_group = None
self.tp_size = 1
self.sequence_parallel = False
self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
)
def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
if self.fp8_meta_tensors_initialized:
# Handle changed amax history size.
curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0]
need_len = self.fp8_meta["recipe"].amax_history_len
if need_len < curr_len:
self.fp8_meta[fp8_meta_tensor_key].amax_history = (
self.fp8_meta[fp8_meta_tensor_key]
.amax_history[: self.fp8_meta["recipe"].amax_history_len].clone()
)
elif need_len > curr_len:
extra_rows = need_len - curr_len
self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad(
self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows)
)
return
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
num_fp8_tensors = (
self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2
)
self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones(
num_fp8_tensors, dtype=torch.float32, device="cuda"
)
self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones(
num_fp8_tensors, dtype=torch.float32, device="cuda"
)
self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros(
self.fp8_meta["recipe"].amax_history_len,
num_fp8_tensors,
dtype=torch.float32,
device="cuda",
)
# Needed for calculation of scale inverses to
# preserve scale_inv when caching FP8 weights
if fwd:
# [True, False, True]: -> [input, weight, output]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, False, True] * self.fp8_meta["num_gemms"]
).cuda()
else:
# [True, True]: -> [grad_output, grad_input]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, True] * self.fp8_meta["num_gemms"]
).cuda()
def init_fp8_meta_tensors(self) -> None:
"""Init scales and amaxes."""
self.set_meta_tensor(True)
self.set_meta_tensor(False)
self.fp8_meta_tensors_initialized = True
def get_extra_state(self) -> torch.Tensor:
"""Save before checkpointing."""
state = None
if self.fp8 or self.fp8_calibration:
state = {}
state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
state["global_fp8_buffer"] = get_global_fp8_buffer()
# Store other picklable values.
extra = {}
for k, v in self.fp8_meta.items():
if isinstance(v, (bool, int, float, str)):
extra[k] = v
state["extra_fp8_variables"] = extra
state_serialized = pickle.dumps(state)
state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8))
return state_tensor
def set_extra_state(self, state: torch.Tensor) -> None:
"""Load previous state."""
if state is None:
return
# Maintain backward compatibility with v0.2.0 and older.
if isinstance(state, list):
warnings.warn(
"This checkpoint format is deprecated and will be"
"removed in a future release of Transformer Engine"
)
# Retrieve checkpointed items.
scale_fwd = state[0]
amax_history_fwd = state[1]
scale_bwd = state[2]
amax_history_bwd = state[3]
self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
self.fp8_meta["num_gemms"] = (
amax_history_fwd.shape[1] // 2
) # Two FWD tensors per GEMM
# Initialize before loading
self.init_fp8_meta_tensors()
self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
# Restore global FP8 buffer state.
set_global_fp8_buffer(state[4])
self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
self.fp8_meta["autocast_id_fwd"] = state[8]
self.fp8_meta["autocast_id_bwd"] = state[9]
return
if isinstance(state, torch.Tensor):
state = pickle.loads(state.detach().cpu().numpy().tobytes())
if state is None:
return
# Restore global FP8 buffer states.
set_global_fp8_buffer(state["global_fp8_buffer"])
# Load extra items.
self.fp8_meta.update(state["extra_fp8_variables"])
self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
# Initialize before loading.
self.init_fp8_meta_tensors()
self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
# Backwards compatibility: compute scale inv if it wasn't saved in the extra state.
if "scale_inv_fwd" not in state or "scale_inv_bwd" not in state:
assert (
"scale_inv_fwd" not in state and "scale_inv_bwd" not in state
), "Invalid state, began saving scale_inv_fwd and scale_inv_bwd at the same time"
self.fp8_meta["scaling_fwd"].scale_inv.copy_(1.0/state["scale_fwd"])
self.fp8_meta["scaling_bwd"].scale_inv.copy_(1.0/state["scale_bwd"])
else:
self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
def set_activation_dtype(self, inp: torch.Tensor) -> None:
"""Get activation data type for AMP."""
# Native AMP (`torch.autocast`) gets highest priority
if torch.is_autocast_enabled():
self.activation_dtype = torch.get_autocast_gpu_dtype()
return
# All checks after this have already been performed once, thus skip
# We assume that user doesn't change input types across iterations
if hasattr(self, "activation_dtype"):
return
assert all(
(
(inp.dtype == param.dtype) if param is not None else True
for param in self.parameters()
)
), (
"Data type for activations and weights must "
"match when outside of autocasted region"
)
assert all(
(
(inp.dtype == buf.dtype) if buf is not None else True
for buf in self.buffers()
)
), (
"Data type for activations and buffers must "
"match when outside of autocasted region"
)
self.activation_dtype = inp.dtype
def set_fp8_weights(self) -> None:
"""Initializes FP8 weights for the module as class attributes. These
are not parameters or buffers since we do not want functions such as
`.to(dtype)` or `.to(device)` to effect them. These also do not need
to be checkpointed. During `init` phase of the module, the attribute
`fp8_weight_shapes` must be populated with the tensor shapes for FP8
weights. This function will iterate over those shapes and initialize
respective attributed named `weight1_fp8`, `weight2_fp8`, ...
"""
if not self.fp8:
return
for i, shape in enumerate(self.fp8_weight_shapes, start=1):
weight_cast_attr = f"weight{i}_fp8"
weight_transpose_attr = f"weight{i}_t_fp8"
if (
hasattr(self, weight_cast_attr)
and getattr(self, weight_cast_attr).shape == shape
):
return
setattr(
self,
weight_cast_attr,
torch.empty(
shape,
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
)
setattr(
self,
weight_transpose_attr,
torch.empty(
shape[1],
shape[0],
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group."""
self.tp_group = tp_group
self.tp_group_initialized = True
# This routine is shared across FP8 and FP8_calibration paths so should not actually
# assume FP8 execution.
def fp8_init(self, num_gemms: int = 1) -> None:
"""Initialize fp8 related metadata and tensors during fprop."""
self.fp8 = is_fp8_enabled()
self.fp8_calibration = is_fp8_calibration()
if self.fp8 or self.fp8_calibration:
# FP8 init has already been run and recipe is the same, don't do anything.
if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
return
# Set FP8, recipe, and other FP8 metadata
self.fp8_meta["recipe"] = get_fp8_recipe()
self.fp8_meta["num_gemms"] = num_gemms
self.fp8_meta["fp8_group"] = get_fp8_group()
# Set FP8_MAX per tensor according to recipe
self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd
# Allocate scales and amaxes
self.init_fp8_meta_tensors()
self.fp8_initialized = True
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
return
@contextmanager
def prepare_forward(
self,
inp: torch.Tensor,
is_first_microbatch: Union[bool, None],
num_gemms: int = 1,
) -> None:
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to set up the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else:
assert inp.is_cuda, "TransformerEngine needs CUDA."
if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized."
self.set_activation_dtype(inp)
self.fp8_init(num_gemms=num_gemms)
self.set_fp8_weights()
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if self.fp8 and self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, \
"Amax reduction across tensor parallel group is " \
"necessary when using sequence parallelism with FP8."
# Previous iteration was grad_enabled
if self.fp8_meta.get("update_amax_and_scale_fwd", False):
if self.fp8_meta["recipe"].reduce_amax:
copy_amax_from_global_buffer(self.fp8_meta, forward=True)
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
else:
amax_and_scale_update(
self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
)
if self.fp8 and self.training:
# Setup for amax reduction
if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
if amax_reduce_handle_fwd is not None:
amax_reduce_handle_fwd.wait()
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else:
self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
self.fp8_meta["autocast_id_fwd_stack"].append(
self.fp8_meta["autocast_id_fwd"]
)
add_amax_to_global_buffer(self.fp8_meta, forward=True)
self.fp8_meta["update_amax_and_scale_fwd"] = True
else:
self.fp8_meta["update_amax_and_scale_fwd"] = False
# Activation recomputation is used and this is the first forward phase.
if (
self.fp8
and self.training
and is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
):
copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
yield inp.contiguous()
if self.fp8 and in_fp8_activation_recompute_phase():
restore_fp8_meta_tensors(self.fp8_meta)
return
if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
reduce_func = partial(
global_amax_reduction,
self.fp8_meta,
self.tp_group,
self.tp_size,
forward=True
)
setup_amax_forward_global_reduce_func(reduce_func)
def set_nccl_overlap_warning_if_tp(self) -> None:
"""When using TP, the NCCL communication needs to be scheduled
before the GEMM for there to be a guaranteed overlap. From the
host side in TE, the comm calls are always launched first, but
to ensure that the GEMM isn't scheduled first, the environment
variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
force a single channel.
"""
if self.tp_size == 1:
return
num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
if num_cuda_work_queues != 1:
warnings.warn(
"To guarantee overlapping TP and SP collectives with the backward"
"GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
)
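# Illustrative note (not part of the library source): in practice the single
# CUDA channel described above is requested by exporting the variable before
# the job launches, e.g. on the command line (launcher and script names are
# placeholders):
#   CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 train.py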
@staticmethod
def grad_output_preprocess(
ctx, grad_output: torch.Tensor, row_parallel_mode: bool
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Utility function for backward.
Returns tuple in order (all optional/None based on training precision/recipe):
R1: gathered `grad_output` in higher precision.
R2: gathered `grad_output` in FP8.
R3: R2 transposed.
R4: bias gradient on R1.
"""
grad_output = grad_output.contiguous()
grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
# No-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8:
if gather_grad_output:
if not ctx.ub_split_ag:
grad_output_mat, _ = gather_along_first_dim(
grad_output_mat, ctx.tp_group
)
else:
ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True)
grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1)
return grad_output_mat, None, None, None
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# FP8 case with non-FP8 wgrad
if (
gather_grad_output
and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
):
assert (
not ctx.ub_split_ag
), "override_linear_precision.wgrad not supported with ub_split_ag"
grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
# FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
elif gather_grad_output:
if ctx.use_bias:
grad_bias = grad_output_mat.sum(dim=0)
else:
grad_bias = None
if ctx.ub_split_ag:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
else:
grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
out=grad_output_c,
)
if not ctx.ub_split_ag:
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
else:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
grad_output_t = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
# FP8 case without gather: cast, transpose, bgrad fused
if ctx.use_bias:
grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
else:
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
grad_output_c, grad_output_t = fp8_cast_transpose_fused(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
else:
grad_output_t = None
grad_output_c = cast_to_fp8(
grad_output_mat,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
)
grad_bias = None
return grad_output_mat, grad_output_c, grad_output_t, grad_bias
def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor:
"""No-op replacement of `torch.cat`. The buffer and split parameters must occupy
the same memory region. If this is not the case, then the split parameters
are concatenated and the buffer is overwritten. The parameters' memory is then
re-assigned to point to the buffer to avoid subsequent concatenations.
"""
assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
full_param_buffer = getattr(self, buffer_name)
split_size = full_param_buffer.shape[0] // len(pnames)
params = [getattr(self, name) for name in pnames]
for i, p in enumerate(params):
if p.data.data_ptr() != full_param_buffer[i*split_size : (i+1)*split_size].data_ptr():
with torch.no_grad():
setattr(self, buffer_name, torch.cat(params))
for j, pname in enumerate(pnames):
full_param_buffer = getattr(self, buffer_name)
setattr(self, pname,
Parameter(full_param_buffer[j*split_size : (j+1)*split_size]))
break
return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])
@abstractmethod
def forward(self):
"""Needs override."""
class _LayerNormLinear(torch.autograd.Function):
"""LayerNormLinear semi-top level module
Calls custom cuda extensions.
"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: Union[torch.Tensor, None],
bias: torch.Tensor,
use_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
return_layernorm_output: bool,
is_grad_enabled: bool,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_ag: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
), "Input and weight dimensions are not compatible for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)
# If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if ub_split_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
if ub_split_ag:
dim_size = list(inputmat.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0)
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
if is_grad_enabled:
if not ub_split_ag:
ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
_, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
ln_out=ln_out
)
else:
mu = rsigma = None
ln_out = layernorm_fwd_fp8_inf(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
zero_centered_gamma,
)
else:
if is_grad_enabled:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out_return, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out = cast_to_fp8(
ln_out_return,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
if is_grad_enabled:
if ub_split_ag:
_, mu, rsigma = tex.layernorm_fwd_noalloc(
inputmat, ln_weight, ln_bias, ln_out, eps,
fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out_return = ln_out
# Column Parallel Linear
if ub_split_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
elif parallel_mode == "column" and sequence_parallel:
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
else:
ln_out_total = ln_out
if fp8:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
else activation_dtype
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if is_grad_enabled:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
out = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
ln_out_total,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
)
else:
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
if fp8_calibration:
# amax of input
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(ln_out_total).float()
# amax of weight
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
out, _, _ = gemm(
weight,
ln_out_total,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
)
if is_grad_enabled:
ctx.save_for_backward(
inputmat,
ln_weight,
mu,
rsigma,
weight,
weight_t_fp8,
ln_out,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.return_layernorm_output = return_layernorm_output
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.requires_dgrad = inp.requires_grad
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp.shape[1:-1], out.shape[-1])
if return_layernorm_output:
return out, ln_out_return.view_as(inp)
return out
@staticmethod
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
):
(
inputmat,
ln_weight,
mu,
rsigma,
weight,
weight_t_fp8,
ln_out,
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_bulk_dgrad = False
if ctx.ub_bulk_dgrad:
dim_size = list(ln_out.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_dgrad")
ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_outputs[0], ctx.parallel_mode == "row"
)
if ctx.ub_bulk_wgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_bulk_wgrad = False
# Column Parallel Linear
# Overlap input AG with dgrad
if (not ctx.ub_bulk_dgrad) and ctx.parallel_mode == "column" and ctx.sequence_parallel:
ln_out_total, handle = gather_along_first_dim(
ln_out, ctx.tp_group, async_op=True
)
else:
ln_out_total = ln_out
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
dgrad_size = list(grad_output.size())
dgrad_size[1] = weight.size(1)
if ctx.ub_bulk_wgrad: # allocate dgrad output
ub_obj_dgrad = get_ub("qkv_wgrad")
dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
else:
dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
out=dgrad,
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
out=dgrad,
layout="NN",
grad=True,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
)
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if not ctx.ub_bulk_dgrad:
handle.wait()
if not ctx.ub_bulk_wgrad:
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = fp8_gemm(
ln_out_total_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
else:
ln_out_total_c = cast_from_fp8(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
wgrad, _, _ = gemm(
ln_out_total_c,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
ln_out_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
if ctx.ub_bulk_wgrad:
dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
# Column Parallel Linear
elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
# LayerNorm gradient
d_ln_out = dgrad.view(inputmat.shape)
# Residual gradient
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
if not ctx.use_bias:
grad_bias = None
return (
dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
dbeta,
wgrad if weight.requires_grad else None,
None,
None,
grad_bias,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class LayerNormLinear(TransformerEngineBaseModule):
r"""
Applies layer normalization followed by linear transformation to the incoming data.
Parameters
----------
in_features : int
size of each input sample.
out_features : int
size of each output sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation.
Example use case: residual connection for transformer module is
taken post layernorm.
parameters_split : Tuple[str, ...], default = None
if a tuple of strings is provided, the weight and bias parameters of the
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
skip_weight_param_allocation: bool, default = `False`
if set to `True`, weight parameter is not allocated and must be
passed as a keyword argument `weight` during the forward pass.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(
self,
in_features: int,
out_features: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
fuse_wgrad_accumulation: bool = False,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
params_dtype: torch.dtype = torch.float32,
parallel_mode: Optional[str] = None,
return_layernorm_output: bool = False,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
zero_centered_gamma: bool = False,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
if init_method is None:
init_method = get_default_init_method()
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
self.eps = eps
self.layer_norm_weight = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.layer_norm_bias = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
self.reset_layer_norm_parameters()
if not skip_weight_param_allocation:
self.register_buffer("weight_tensor",
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
initialize_affine_weight_gpu(
self.weight_tensor,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias:
self.register_buffer("bias_tensor",
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
else:
self.register_buffer("bias_tensor",
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
with torch.no_grad():
self.bias_tensor.zero_()
if parameters_split is None:
parameters_split = ("",)
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
self.weight_names = []
self.bias_names = []
for i, pname in enumerate(parameters_split):
wname = pname + "weight"
bname = pname + "bias"
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
self.register_buffer(bname,
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
self.weight_names.append(wname)
self.bias_names.append(bname)
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
else:
self.gemm_bias_unfused_add = False
# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
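# For example, to reserve 8 SMs for a communication kernel overlapped with the
# forward and backward LayerNorm (illustrative values):
#     NVTE_FWD_LAYERNORM_SM_MARGIN=8 NVTE_BWD_LAYERNORM_SM_MARGIN=8 python train.py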
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
if not self.zero_centered_gamma:
init.ones_(self.layer_norm_weight)
else:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)
def forward(
self,
inp: torch.Tensor,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply layer normalization to the input followed by a linear transformation.
Parameters
----------
inp : torch.Tensor
Input tensor.
weight : torch.Tensor, default = None
An optional weight tensor for the module. This argument is compulsory if the
module is initialized with `skip_weight_param_allocation=True`.
bias : torch.Tensor, default = None
An optional bias tensor for the module. This argument is compulsory if the
module is initialized with `skip_weight_param_allocation=True` and either
`use_bias` or `return_bias` is set to `True`.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = (
bias if bias is not None
else self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names)
)
weight_tensor = (
weight if weight is not None
else self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names)
)
if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply
args = []
else:
fwd_fn = _LayerNormLinear.forward
args = [None]
args += (
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
torch.is_grad_enabled(),
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_ag,
)
out = fwd_fn(*args)
if self.return_layernorm_output:
out, ln_out = out
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
if self.return_bias:
if self.return_layernorm_output:
return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out
return out, cast_if_needed(bias_tensor, self.activation_dtype)
if self.return_layernorm_output:
return out, ln_out
return out
class _Linear(torch.autograd.Function):
"""Linear semi-top level module
Calls custom cuda extensions.
"""
@staticmethod
def forward(
ctx,
weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: Union[torch.Tensor, None],
inp: torch.Tensor,
bias: torch.Tensor,
use_bias: bool,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
ub_split_rs: bool,
ub_split_ag: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
), "Input and weight dimensions are not compatible for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
if ub_split_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_no_fp8 = inputmat
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not fp8_meta["recipe"].override_linear_precision.wgrad:
if is_grad_enabled:
inputmat, inputmat_t = fp8_cast_transpose_fused(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat, inputmat_t = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
), None
# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat
if fp8:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
else activation_dtype
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if is_grad_enabled:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
if ub_split_rs:
ub_obj_projout = get_ub("proj_fprop")
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
_ = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
inputmat_total,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_projout if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
)
else:
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
if fp8_calibration:
# amax of input
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(inputmat_total).float()
# amax of weight
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
if ub_split_rs:
ub_obj_projout = get_ub("proj_fprop")
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
_, _, _ = gemm(
weight,
inputmat_total,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
out=out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_projout if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
)
if is_grad_enabled:
fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
ctx.save_for_backward(
inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
inputmat_t if weight.requires_grad and fp8_wgrad else None,
weight,
weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.ub_split_ag = ub_split_ag
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
# Row Parallel Linear
if ub_split_rs:
out = rs_out
elif parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
):
(
inputmat,
inputmat_t,
weight,
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.ub_split_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_split_ag = False
if ctx.ub_split_ag:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub("proj_dgrad")
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_output, ctx.parallel_mode == "row"
)
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
inputmat_t_total, handle = gather_along_last_dim(
inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad
)
else:
inputmat_total, handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
)
else:
inputmat_t_total = inputmat_t
inputmat_total = inputmat
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
if ctx.requires_dgrad:
if ctx.fp8:
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
)
else:
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if ctx.ub_split_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
wgrad = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
wgrad, _, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
if not ctx.use_bias:
grad_bias = None
return (
wgrad if weight.requires_grad else None,
None,
None,
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
grad_bias,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class Linear(TransformerEngineBaseModule):
"""
Applies a linear transformation to the incoming data :math:`y = xA^T + b`
On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
Parameters
----------
in_features : int
size of each input sample.
out_features : int
size of each output sample.
bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
parameters_split : Tuple[str, ...], default = None
if a tuple of strings is provided, the weight and bias parameters of the
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
skip_weight_param_allocation: bool, default = `False`
if set to `True`, weight parameter is not allocated and must be
passed as a keyword argument `weight` during the forward pass.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(
self,
in_features: int,
out_features: int,
sequence_parallel: bool = False,
fuse_wgrad_accumulation: bool = False,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
params_dtype: torch.dtype = torch.float32,
parallel_mode: Optional[str] = None,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.parameters_split = parameters_split
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
if ub_split_rs or ub_split_ag:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
if init_method is None:
init_method = get_default_init_method()
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
if not skip_weight_param_allocation:
self.register_buffer("weight_tensor",
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
initialize_affine_weight_gpu(
self.weight_tensor,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias:
self.register_buffer("bias_tensor",
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
else:
self.register_buffer("bias_tensor",
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
with torch.no_grad():
self.bias_tensor.zero_()
if parameters_split is None:
parameters_split = ("",)
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
self.weight_names = []
self.bias_names = []
for i, pname in enumerate(parameters_split):
wname = pname + "weight"
bname = pname + "bias"
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
self.register_buffer(bname,
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
self.weight_names.append(wname)
self.bias_names.append(bname)
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
else:
self.gemm_bias_unfused_add = False
def forward(
self,
inp: torch.Tensor,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply the linear transformation to the input.
Parameters
----------
inp : torch.Tensor
Input tensor.
weight : torch.Tensor, default = None
An optional weight tensor for the module. This argument is compulsory if the
module is initialized with `skip_weight_param_allocation=True`.
bias : torch.Tensor, default = None
An optional bias tensor for the module. This argument is compulsory if the
module is initialized with `skip_weight_param_allocation=True` and either
`use_bias` or `return_bias` is set to `True`.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = (
bias if bias is not None
else self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names)
)
weight_tensor = (
weight if weight is not None
else self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names)
)
if torch.is_grad_enabled():
linear_fn = _Linear.apply
args = []
else:
linear_fn = _Linear.forward
args = [None]
args += (
weight_tensor,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
inp,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
self.ub_split_rs,
self.ub_split_ag,
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
if self.return_bias:
return out, cast_if_needed(bias_tensor, self.activation_dtype)
return out
from typing import Union, Optional, Callable, Tuple, Dict, Any
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_extensions as tex
from .base import (
get_workspace,
_prepare_backward,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..jit import (
bias_gelu_fused,
bgrad_dgelu_fused,
set_jit_fusion_options,
warmup_jit_bias_gelu_all_dtypes,
)
from ..utils import (
divide,
get_default_init_method,
cast_if_needed,
check_dim_for_fp8_forward_exec,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
allreduce,
initialize_affine_weight_gpu,
reduce_scatter_along_first_dim,
gather_along_first_dim,
)
from ..cpp_extensions import (
fp8_gemm,
gemm,
fp8_cast_transpose_fused,
fp8_gelu,
fp8_cast_transpose_bgrad_dgelu_fused,
layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
cast_to_fp8,
cast_from_fp8,
)
from ..constants import dist_group_type, TE_DType
__all__ = ["LayerNormMLP"]
class _LayerNormMLP(torch.autograd.Function):
......@@ -3412,159 +1146,3 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.return_layernorm_output:
return out, ln_out
return out
class _LayerNorm(torch.autograd.Function):
"""functional LayerNorm"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
eps: float,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "LayerNorm not possible"
inputmat = inp.view((-1, in_features))
ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
ln_bias, eps, fwd_ln_sm_margin,
zero_centered_gamma)
ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
return ln_out.view_as(inp)
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_ln_out = grad_output.view(inputmat.shape)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None
class LayerNorm(torch.nn.Module):
r"""
Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
size :attr:`hidden_size`
Parameters
----------
hidden_size : int
size of each input sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
params_dtype: torch.dtype = torch.float32,
zero_centered_gamma: bool = False,
) -> None:
super().__init__()
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.weight = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.bias = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.weight, "sequence_parallel", sequence_parallel)
setattr(self.bias, "sequence_parallel", sequence_parallel)
self.reset_layer_norm_parameters()
# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def load_state_dict(
self,
state_dict: Mapping[str, Any],
strict: bool = True,
) -> None:
"""Override PyTorch loader to maintain backward compatibility
with previous version of LayerNorm parameter names.
"""
if "layer_norm_weight" in state_dict:
state_dict["weight"] = state_dict["layer_norm_weight"]
del state_dict["layer_norm_weight"]
if "layer_norm_bias" in state_dict:
state_dict["bias"] = state_dict["layer_norm_bias"]
del state_dict["layer_norm_bias"]
super().load_state_dict(state_dict, strict)
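# Illustrative: a checkpoint saved with the older parameter names is remapped on load:
#
#     old_state = {"layer_norm_weight": w, "layer_norm_bias": b}  # hypothetical tensors
#     ln.load_state_dict(old_state)  # stored under ln.weight / ln.bias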
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
if not self.zero_centered_gamma:
init.ones_(self.weight)
else:
init.zeros_(self.weight)
init.zeros_(self.bias)
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""LayerNorm FWD"""
# Maintain backward compatibility.
if hasattr(self, "layer_norm_weight"):
setattr(self, "weight", self.layer_norm_weight)
if hasattr(self, "layer_norm_bias"):
setattr(self, "bias", self.layer_norm_bias)
return _LayerNorm.apply(
inp,
self.weight,
self.bias,
self.eps,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""
from typing import Union, Optional, Callable, Tuple, Dict, Any
import torch
from torch.nn.parameter import Parameter
import transformer_engine_extensions as tex
from .base import (
get_workspace,
_prepare_backward,
get_ub,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import (
divide,
get_default_init_method,
cast_if_needed,
check_dim_for_fp8_forward_exec,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
allreduce,
initialize_affine_weight_gpu,
reduce_scatter_along_first_dim,
gather_along_first_dim,
gather_along_last_dim,
)
from ..cpp_extensions import (
fp8_gemm,
gemm,
fp8_cast_transpose_fused,
cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type
__all__ = ["Linear"]
class _Linear(torch.autograd.Function):
"""Linear semi-top level module
Calls custom cuda extensions.
"""
@staticmethod
def forward(
ctx,
weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: Union[torch.Tensor, None],
inp: torch.Tensor,
bias: torch.Tensor,
use_bias: bool,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
is_grad_enabled: bool,
ub_split_rs: bool,
ub_split_ag: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
assert (
not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
), "Input and weight dimensions are not compatible for FP8 execution."
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
if ub_split_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_no_fp8 = inputmat
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not fp8_meta["recipe"].override_linear_precision.wgrad:
if is_grad_enabled:
inputmat, inputmat_t = fp8_cast_transpose_fused(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
inputmat, inputmat_t = cast_to_fp8(
inputmat,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
), None
# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat
if fp8:
bias_dtype = (
torch.bfloat16
if activation_dtype == torch.float32
else activation_dtype
)
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
if update_fp8_weights:
if is_grad_enabled:
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
cast_out=weight_fp8,
transpose_out=weight_t_fp8,
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
if ub_split_rs:
ub_obj_projout = get_ub("proj_fprop")
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
_ = fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
inputmat_total,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_projout if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
)
else:
# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
bias = cast_if_needed(bias, activation_dtype) if use_bias else bias
if fp8_calibration:
# amax of input
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(inputmat_total).float()
# amax of weight
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
if ub_split_rs:
ub_obj_projout = get_ub("proj_fprop")
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
_, _, _ = gemm(
weight,
inputmat_total,
activation_dtype,
get_workspace(),
bias=bias,
use_bias=use_bias,
out=out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_projout if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
)
if is_grad_enabled:
fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
ctx.save_for_backward(
inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
inputmat_t if weight.requires_grad and fp8_wgrad else None,
weight,
weight_t_fp8 if fp8 else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.ub_split_ag = ub_split_ag
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
# Row Parallel Linear
if ub_split_rs:
out = rs_out
elif parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
with _prepare_backward(
ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
):
(
inputmat,
inputmat_t,
weight,
weight_t_fp8,
fwd_scale_inverses,
) = ctx.saved_tensors
if ctx.ub_split_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_split_ag = False
if ctx.ub_split_ag:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub("proj_dgrad")
(
grad_output,
grad_output_c,
grad_output_t,
grad_bias,
) = TransformerEngineBaseModule.grad_output_preprocess(
ctx, grad_output, ctx.parallel_mode == "row"
)
# Column Parallel Linear
# Overlap input AG with dgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
inputmat_t_total, handle = gather_along_last_dim(
inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad
)
else:
inputmat_total, handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
)
else:
inputmat_t_total = inputmat_t
inputmat_total = inputmat
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
)
fp8_dtype_backward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=False
)
if ctx.requires_dgrad:
if ctx.fp8:
dgrad = fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
grad_output_c,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
)
else:
dgrad, _, _ = gemm(
weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
grad=True,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
)
# Overlap dgrad-RS/AR with wgrad
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if weight.requires_grad:
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if ctx.ub_split_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
wgrad = fp8_gemm(
inputmat_t_total,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
tex.FP8BwdTensors.GRAD_OUTPUT1,
fp8_dtype_backward,
ctx.activation_dtype,
get_workspace(),
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=_2X_ACC_WGRAD,
)
else:
wgrad, _, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
inputmat_total,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
# Column Parallel Linear
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
if not ctx.use_bias:
grad_bias = None
return (
wgrad if weight.requires_grad else None,
None,
None,
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
grad_bias,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class Linear(TransformerEngineBaseModule):
"""
Applies a linear transformation to the incoming data :math:`y = xA^T + b`
On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
Parameters
----------
in_features : int
size of each input sample.
out_features : int
size of each output sample.
bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
parameters_split : Tuple[str, ...], default = None
if a tuple of strings is provided, the weight and bias parameters of the
module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
split along the first dimension, where `N` is the length of the argument
and the strings contained are the names of the split parameters.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
`set_tensor_parallel_group(tp_group)` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
parallel_mode : {None, 'column', 'row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
When set to `None`, no communication is performed.
skip_weight_param_allocation: bool, default = `False`
if set to `True`, weight parameter is not allocated and must be
passed as a keyword argument `weight` during the forward pass.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(
self,
in_features: int,
out_features: int,
sequence_parallel: bool = False,
fuse_wgrad_accumulation: bool = False,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
params_dtype: torch.dtype = torch.float32,
parallel_mode: Optional[str] = None,
skip_weight_param_allocation: bool = False,
parameters_split: Optional[Tuple[str, ...]] = None,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.parameters_split = parameters_split
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
if ub_split_rs or ub_split_ag:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
if init_method is None:
init_method = get_default_init_method()
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
if not skip_weight_param_allocation:
self.register_buffer("weight_tensor",
torch.empty(
self.out_features,
self.in_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
initialize_affine_weight_gpu(
self.weight_tensor,
init_method,
get_rng_state_tracker,
partition_dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias:
self.register_buffer("bias_tensor",
torch.empty(
self.out_features,
device=torch.cuda.current_device(),
dtype=params_dtype),
persistent=False)
else:
self.register_buffer("bias_tensor",
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
with torch.no_grad():
self.bias_tensor.zero_()
if parameters_split is None:
parameters_split = ("",)
assert (
self.out_features % len(parameters_split) == 0
), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
split_size = self.out_features // len(parameters_split)
self.weight_names = []
self.bias_names = []
for i, pname in enumerate(parameters_split):
wname = pname + "weight"
bname = pname + "bias"
self.register_parameter(
wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
)
set_tensor_model_parallel_attributes(
tensor=getattr(self, wname),
is_parallel=True,
dim=1 if parallel_mode == "row" else 0,
stride=1,
)
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
else:
self.register_buffer(bname,
torch.Tensor().to(dtype=params_dtype,
device=torch.cuda.current_device()),
persistent=False)
if parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
self.weight_names.append(wname)
self.bias_names.append(bname)
self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
# For row-parallel layers (RPL), the bias has to be added after the TP
# collectives, so it cannot be fused with the GEMM.
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
else:
self.gemm_bias_unfused_add = False
def forward(
self,
inp: torch.Tensor,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply the linear transformation to the input.
Parameters
----------
inp : torch.Tensor
Input tensor.
weight : torch.Tensor, default = None
An optional weight tensor for the module. This argument is required if the
module was initialized with `skip_weight_param_allocation=True`.
bias : torch.Tensor, default = None
An optional bias tensor for the module. This argument is required if the
module was initialized with `skip_weight_param_allocation=True` and either
`use_bias` or `return_bias` is set to `True`.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
with self.prepare_forward(inp, is_first_microbatch) as inp:
bias_tensor = (
bias if bias is not None
else self.bias if self.parameters_split is None
else self.bias_tensor if not torch.is_grad_enabled()
else self.noop_cat("bias_tensor", self.bias_names)
)
weight_tensor = (
weight if weight is not None
else self.weight if self.parameters_split is None
else self.weight_tensor if not torch.is_grad_enabled()
else self.noop_cat("weight_tensor", self.weight_names)
)
if torch.is_grad_enabled():
linear_fn = _Linear.apply
args = []
else:
linear_fn = _Linear.forward
args = [None]
args += (
weight_tensor,
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
inp,
bias_tensor,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.fp8_meta,
self.fuse_wgrad_accumulation,
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
torch.is_grad_enabled(),
self.ub_split_rs,
self.ub_split_ag,
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
if self.return_bias:
return out, cast_if_needed(bias_tensor, self.activation_dtype)
return out
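# Illustrative sketch (not part of this module): how `return_bias` and
# `is_first_microbatch` are typically combined during gradient accumulation.
# The layer sizes, dummy data, and loss below are assumptions for the example.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024, bias=True, return_bias=True)
microbatches = [torch.randn(32, 1024, device="cuda") for _ in range(4)]
for i, mb in enumerate(microbatches):
    out, bias = layer(mb, is_first_microbatch=(i == 0))
    out = out + bias                         # bias add deferred from the GEMM
    (out.float().mean() / len(microbatches)).backward()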
......@@ -4,19 +4,15 @@
"""Transformer."""
import os
import math
import warnings
from importlib.metadata import version
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union
from pkg_resources import packaging
from typing import Any, Callable, Optional, Union
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormLinear, Linear, LayerNormMLP, LayerNorm
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm
from transformer_engine.pytorch.attention import MultiHeadAttention
from transformer_engine.pytorch.jit import (
set_jit_fusion_options,
warmup_jit_bias_dropout_add_all_dtypes,
......@@ -25,32 +21,21 @@ from transformer_engine.pytorch.jit import (
bias_dropout_add_fused_inference,
)
from transformer_engine.pytorch.utils import (
divide,
attention_mask_func,
split_tensor_along_dim,
cast_if_needed,
get_default_init_method,
get_device_compute_capability,
)
from transformer_engine.pytorch.constants import (
AttnMaskTypes,
AttnTypes,
LayerTypes,
dist_group_type,
)
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax
from transformer_engine.pytorch.distributed import (
get_distributed_world_size,
checkpoint,
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.distributed import get_distributed_world_size
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("1.0.2")
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
__all__ = ["DotProductAttention", "TransformerLayer"]
__all__ = ["TransformerLayer"]
class DropPath(torch.nn.Module):
......@@ -77,868 +62,7 @@ class DropPath(torch.nn.Module):
output = hidden_state.div(keep_prob) * random_tensor
return output
class _SplitLastDim(torch.autograd.Function):
""""""
@staticmethod
def forward(ctx,
mixed_x_layer: torch.Tensor,
num_parts: int
) -> Tuple[torch.Tensor, ...]:
return split_tensor_along_dim(mixed_x_layer, -1, num_parts)
@staticmethod
def backward(ctx,
*grad_outputs):
assert len(grad_outputs) > 0, "No gradients received for backprop!"
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0].storage().data_ptr()
shape = grad_outputs[0].shape
last_dim_size = grad_outputs[0].shape[-1]
for i, tensor in enumerate(grad_outputs):
if (tensor.stride() != strides or
tensor.shape != shape or
tensor.storage().data_ptr() != data_ptr or
tensor.storage_offset() != i * last_dim_size):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(device=grad_outputs[0].device,
dtype=grad_outputs[0].dtype)
new_shape = list(shape)
new_shape[-1] = new_shape[-1] * len(grad_outputs)
ret.set_(grad_outputs[0].storage(),
grad_outputs[0].storage_offset(),
new_shape,
grad_outputs[0].stride()
)
return ret, None
return torch.cat(grad_outputs, dim = -1), None
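# Minimal demonstration (assumption, for exposition only) of the no-op path
# above: when the incoming gradients are equally-shaped, equally-strided views
# of one contiguous buffer at offsets 0, d, 2*d, the backward can alias that
# buffer instead of materializing a `torch.cat`.
import torch

buf = torch.randn(2, 4, 12)                                   # last dim = 3 * 4
g0, g1, g2 = buf[..., 0:4], buf[..., 4:8], buf[..., 8:12]
same_storage = len({g.storage().data_ptr() for g in (g0, g1, g2)}) == 1
offsets = [g.storage_offset() for g in (g0, g1, g2)]          # [0, 4, 8]
assert same_storage and offsets == [0, 4, 8]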
class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
"""
def __init__(
self,
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attn_mask_type: str = "causal",
layer_number: Optional[int] = None,
) -> None:
super().__init__()
assert (
attn_mask_type in AttnMaskTypes
), f"attn_mask_type {attn_mask_type} not supported"
self.norm_factor = norm_factor
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.scale_mask_softmax = FusedScaleMaskSoftmax(
attn_mask_type,
attention_mask_func,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout)
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""core attention fprop"""
batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
apply_qk_layer_scaling = self.layer_number is not None and key_layer.dtype == torch.float16
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.reshape(
output_size[2], output_size[0] * output_size[1], -1
)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
# preallocting result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
scale = self.norm_factor
if apply_qk_layer_scaling:
scale *= self.layer_number
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / scale),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# attention scores and attention mask [b, np, sq, sk]
softmax_scale = self.layer_number if apply_qk_layer_scaling else None
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, softmax_scale)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with self.attention_dropout_ctx():
attention_probs = self.attention_dropout(attention_probs)
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
output_size = (
value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3),
)
# change view [sk, b * np, hn]
value_layer = value_layer.reshape(
value_layer.size(0), output_size[0] * output_size[1], -1
)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(
output_size[0] * output_size[1], output_size[2], -1
)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
context_layer = context_layer.view(seqlen, batch_size, -1)
return context_layer
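# Illustrative check (assumption, not used by the module): the `baddbmm` call
# above with `beta=0.0, alpha=1.0 / scale` produces the same scaled attention
# scores as an explicit batched matmul followed by division.
import math
import torch

b_np, sq, sk, hn = 6, 8, 8, 16
q = torch.randn(b_np, sq, hn)
k = torch.randn(b_np, sk, hn)
scale = math.sqrt(hn)
scores = torch.baddbmm(torch.empty(b_np, sq, sk), q, k.transpose(1, 2),
                       beta=0.0, alpha=1.0 / scale)
assert torch.allclose(scores, torch.bmm(q, k.transpose(1, 2)) / scale)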
class _PrepareQKVForFA(torch.autograd.Function):
"""This class converts QKV from interleaved (s, b, ...) layout
to separate contiguous q, k, v tensors in (b, s, ...) layout."""
@staticmethod
def forward(ctx,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor
) -> torch.Tensor:
# All inputs received are non-contiguous tensors.
# The `query_layer` tensor is used to access the
# full memory region of the QKV tensor.
qkv = tex.fa_prepare_fwd(query_layer)
q, k, v = split_tensor_along_dim(qkv, 0, 3)
query_layer = torch.squeeze(q, 0)
key_layer = torch.squeeze(k, 0)
value_layer = torch.squeeze(v, 0)
return query_layer, key_layer, value_layer
@staticmethod
def backward(ctx,
dq: torch.Tensor,
dk: torch.Tensor,
dv: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
dqkv = tex.fa_prepare_bwd(dq, dk, dv)
dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
return dq, dk, dv
def _check_if_interleaved(q, k, v):
data_ptr = q.storage().data_ptr()
check_ptrs = all(x.storage().data_ptr() == data_ptr for x in [q, k, v])
if not check_ptrs:
return False
stride = q.stride()
check_strides = all(stride == x.stride() for x in [q, k, v])
if not check_strides:
return False
shape = q.shape
check_shapes = all(shape == x.shape for x in [q, k, v])
if not check_shapes:
return False
last_dim_size = shape[-1]
check_offsets = all(i * last_dim_size == x.storage_offset()
for i, x in enumerate([q, k, v]))
return check_offsets
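# Illustrative sketch (assumption, not library code): views of a fused QKV
# buffer that satisfy the interleaved-layout check above -- shared storage,
# identical shapes/strides, and storage offsets of 0, hn, and 2 * hn.
import torch

s, b, heads, hn = 8, 2, 4, 16
qkv = torch.randn(s, b, heads, 3 * hn)
q, k, v = qkv[..., :hn], qkv[..., hn:2 * hn], qkv[..., 2 * hn:]
assert _check_if_interleaved(q, k, v)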
class FlashAttention(torch.nn.Module):
"""Dot product attention implementation by using the flash-attn package.
"""
def __init__(
self,
norm_factor: float,
attention_dropout: float = 0.0,
attention_dropout_ctx: Optional[Callable] = nullcontext,
attn_mask_type: str = "causal",
) -> None:
super().__init__()
assert (
_flash_attn_version >= _flash_attn_version_required
), f"FlashAttention minimum version {_flash_attn_version_required} is required."
assert (
attn_mask_type == "causal"
), 'FlashAttention currently only supports causal attention mask.'
self.attn_causal_mask = attn_mask_type == "causal"
self.norm_factor = norm_factor
self.attention_dropout_ctx = attention_dropout_ctx
self.attention_dropout = attention_dropout
self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""flash-attn fprop"""
assert (
(query_layer.dtype in [torch.float16, torch.bfloat16])
and (key_layer.dtype in [torch.float16, torch.bfloat16])
and (value_layer.dtype in [torch.float16, torch.bfloat16])
), 'FlashAttention currently only supports FP16 and BF16.'
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), 'FlashAttention currently only supports CUDA tensors.'
assert (
attention_mask is None
), 'FlashAttention currently does not support external attention mask.'
# The fused QKV re-layout via tex.fa_prepare_fwd currently requires head dim 128,
# sq * b >= 512 tokens, and an interleaved QKV buffer; will be generalized later.
if (query_layer.shape[-1] == 128 and
query_layer.shape[0] * query_layer.shape[1] >= 512 and
_check_if_interleaved(query_layer, key_layer, value_layer)):
query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
key_layer,
value_layer)
else:
query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
for x in (query_layer, key_layer, value_layer)]
batch_size, seqlen = query_layer.shape[0], query_layer.shape[1]
# [b, sq, np, hn]
query_layer, key_layer, value_layer = [
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
for x in [query_layer, key_layer, value_layer]
]
max_seqlen = seqlen
cu_seqlens = torch.arange(
0,
(batch_size + 1) * seqlen,
step=seqlen,
dtype=torch.int32,
device=query_layer.device)
with self.attention_dropout_ctx():
output = flash_attn_unpadded_func(
query_layer, key_layer, value_layer, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
self.attention_dropout if self.training else 0.0,
softmax_scale=1.0/self.norm_factor, causal=self.attn_causal_mask,
deterministic=self.deterministic,
)
# [(b sq), np, hn] -> [sq, b, (np hn)]
return output.view(batch_size, seqlen, -1).transpose(0, 1).contiguous()
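# Illustrative note (assumption, for exposition only): with fixed-length
# sequences, the cumulative sequence-length tensor handed to
# `flash_attn_unpadded_func` above is simply an arithmetic progression.
import torch

batch_size, seqlen = 3, 4
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32)
# cu_seqlens == tensor([ 0,  4,  8, 12], dtype=torch.int32)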
class DotProductAttention(torch.nn.Module):
"""Allows the model to jointly attend to information from different
representation subspaces as described in the paper:
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`attn_mask_type` is set to `"causal"`.
.. warning::
For the default attention mechanism, this module executes a non-deterministic version of
`flash-attn <https://github.com/ksivaman/flash-attention>`_ whenever possible in order to
achieve optimal performance. To observe deterministic behavior, set the environment
variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order to disable
`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
Parameters
----------
num_attention_heads : int
number of attention heads in the transformer layer.
kv_channels : int
number of key-value channels.
attention_dropout: float, default = 0.0
dropout probability for the dropout op during multi-head attention.
attn_mask_type: {'causal', 'padding'}, default = `causal`
type of attention mask passed into softmax operation.
layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules
are concatenated, for instance in consecutive transformer blocks.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
tp_size : int, default = 1
tensor parallel world size.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
"""
def __init__(
self,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float = 0.0,
attn_mask_type: str = "causal",
sequence_parallel: bool = False,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
tp_group: Optional[dist_group_type] = None,
layer_number: Optional[int] = None,
) -> None:
super().__init__()
tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group
self.get_rng_state_tracker = get_rng_state_tracker
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_partition = divide(projection_size, tp_size)
self.hidden_size_per_attention_head = divide(
projection_size, num_attention_heads
)
if sequence_parallel or get_rng_state_tracker is None:
attention_dropout_ctx = nullcontext
else:
attention_dropout_ctx = get_rng_state_tracker().fork
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.device_compute_capability = get_device_compute_capability()
self.use_flash_attention = (
int(os.getenv("NVTE_FLASH_ATTN", "1"))
and attn_mask_type == "causal"
and self.device_compute_capability >= 8.0
)
attn_kwargs = {
"attention_dropout": attention_dropout,
"attention_dropout_ctx": attention_dropout_ctx,
"attn_mask_type": attn_mask_type,
}
if self.use_flash_attention:
self.flash_attention = FlashAttention(norm_factor, **attn_kwargs)
# Instantiating both types since use of flash-attn
# might be ruled out due to forward inputs.
self.unfused_attention = UnfusedDotProductAttention(
norm_factor, **attn_kwargs, layer_number=layer_number)
def _checkpointed_attention_forward(
self,
attention_func: Callable,
*forward_args: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
"""Forward method with activation checkpointing."""
def custom_forward(*inputs):
return attention_func(*inputs)
hidden_states = checkpoint(
custom_forward,
False,
self.get_rng_state_tracker,
self.tp_group,
*forward_args,
)
return hidden_states
def forward(
self,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
checkpoint_core_attention: bool = False,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
.. note::
Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
is set to `"causal"`.
.. note::
Input tensors :attr:`query_layer`, :attr:`key_layer`, and :attr:`value_layer`
must each be of shape (:attr:`sequence_length`, :attr:`batch_size`,
:attr:`num_attention_heads`, :attr:`kv_channels`). Output of shape
(:attr:`sequence_length`, :attr:`batch_size`, :attr:`num_attention_heads`
* :attr:`kv_channels`) is returned.
Parameters
----------
query_layer : torch.Tensor
Query tensor.
key_layer : torch.Tensor
Key tensor.
value_layer : torch.Tensor
Value tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out softmax input when not using flash-attn.
checkpoint_core_attention : bool, default = `False`
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
"""
use_flash_attention = self.use_flash_attention
if (query_layer.dtype not in [torch.bfloat16, torch.float16]
or key_layer.dtype not in [torch.bfloat16, torch.float16]
or value_layer.dtype not in [torch.bfloat16, torch.float16]
or (self.device_compute_capability == 8.6 and key_layer.shape[-1] > 64)
):
use_flash_attention = False
if is_in_onnx_export_mode():
use_flash_attention = False
if use_flash_attention:
if checkpoint_core_attention:
return self._checkpointed_attention_forward(self.flash_attention,
query_layer,
key_layer,
value_layer)
return self.flash_attention(query_layer, key_layer, value_layer)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
query_layer,
key_layer,
value_layer,
attention_mask,
)
return self.unfused_attention(query_layer, key_layer, value_layer, attention_mask)
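# Illustrative sketch (not part of the library): `DotProductAttention` expects
# query/key/value of shape [sequence_length, batch_size, num_heads, kv_channels]
# and returns [sequence_length, batch_size, num_heads * kv_channels]. The
# shapes, dtype, and dropout below are assumptions for the example.
import torch

sq, b, heads, hn = 128, 2, 16, 64
attn = DotProductAttention(num_attention_heads=heads, kv_channels=hn,
                           attention_dropout=0.1)
q = torch.randn(sq, b, heads, hn, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
context = attn(q, k, v)                      # [sq, b, heads * hn]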
class MultiHeadAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
kv_channels: int,
attention_dropout: float,
layernorm_epsilon: float,
init_method: Callable,
output_layer_init_method: Callable,
layer_number: Optional[int] = None,
attn_mask_type: str = "causal",
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
fuse_wgrad_accumulation: bool = False,
get_rng_state_tracker: Optional[Callable] = None,
sequence_parallel: bool = False,
params_dtype: torch.dtype = torch.float32,
return_layernorm_output: bool = False,
input_layernorm: bool = False,
attention_type: str = "self",
set_parallel_mode: bool = False,
fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
bias: bool = True,
) -> None:
super().__init__()
self.layer_number = layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
self.get_rng_state_tracker = get_rng_state_tracker
self.tp_group = tp_group
self.return_layernorm_output = return_layernorm_output
self.params_dtype = params_dtype
self.init_method = init_method
self.attn_mask_type = attn_mask_type
if not fuse_qkv_params:
qkv_weight_interleaved = False
self.qkv_weight_interleaved = qkv_weight_interleaved
assert (
attention_type in AttnTypes
), f"attention_type {attention_type} not supported"
tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_size = tp_size
self.sequence_parallel = (tp_size > 1) and sequence_parallel
self.hidden_size_per_attention_head = kv_channels
self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
"tp_group": tp_group,
"tp_size": tp_size,
"get_rng_state_tracker": get_rng_state_tracker,
"sequence_parallel": sequence_parallel,
"params_dtype": params_dtype,
}
qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self":
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
hidden_size,
3 * hidden_size,
eps=layernorm_epsilon,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
**common_gemm_kwargs,
)
else:
self.qkv = Linear(
hidden_size,
3 * hidden_size,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)
else:
if self.input_layernorm:
self.layernorm_query = LayerNormLinear(
hidden_size,
hidden_size,
eps=layernorm_epsilon,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
**common_gemm_kwargs,
)
else:
self.query_layer = Linear(
hidden_size,
hidden_size,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
**common_gemm_kwargs,
)
self.key_value = Linear(
hidden_size,
2 * hidden_size,
init_method=init_method,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("key_", "value_") if not fuse_qkv_params else None,
**common_gemm_kwargs,
)
# Attention.
self.core_attention = DotProductAttention(
num_attention_heads,
kv_channels,
attention_dropout,
tp_size=tp_size,
get_rng_state_tracker=get_rng_state_tracker,
attn_mask_type=attn_mask_type,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
layer_number=layer_number,
)
# Linear
self.proj = Linear(
hidden_size,
hidden_size,
init_method=output_layer_init_method,
bias=bias,
return_bias=True,
parallel_mode="row" if set_parallel_mode else None,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
**common_gemm_kwargs,
)
def _allocate_memory(
self, inference_max_sequence_len: int, batch_size: int
) -> torch.Tensor:
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device(),
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group"""
self.tp_group = tp_group
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_output: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""MultiHeadAttention FWD"""
# hidden_states: [sq, b, h]
if self.attn_mask_type != "causal" and attention_mask is not None:
assert (
attention_mask.dtype == torch.bool
), "Attention mask must be a boolean tensor"
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params and self.layer_number is not None:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
inference_value_memory,
)
else:
(
inference_key_memory,
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == "self":
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
if self.input_layernorm:
layernorm_qkv_outputs = self.layernorm_qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
mixed_x_layer, layernorm_output = layernorm_qkv_outputs
else:
mixed_x_layer = layernorm_qkv_outputs
else:
mixed_x_layer = self.qkv(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.qkv_weight_interleaved:
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head,
)
# split along last dimension
split_dim = -1
else:
# [sq, b, (np * 3 * hn)] --> [sq, b, 3 * np, hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
3 * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
# split along second last dimension
split_dim = -2
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# mixed_x_layer --> 3 [sq, b, np, hn]
if split_dim == -1 and not is_in_onnx_export_mode():
query_layer, key_layer, value_layer = _SplitLastDim.apply(mixed_x_layer, 3)
else:
query_layer, key_layer, value_layer = split_tensor_along_dim(
mixed_x_layer, split_dim, 3
)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer = self.key_value(
encoder_output,
is_first_microbatch=is_first_microbatch,
)
if self.qkv_weight_interleaved:
# [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
# split along last dimension
split_dim = -1
else:
# [sq, b, (np * 2 * hn)] --> [sq, b, 2 * np, hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + (
2 * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
# split along second last dimension
split_dim = -2
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# mixed_kv_layer --> 2 [sk, b, np, hn]
if split_dim == -1 and not is_in_onnx_export_mode():
key_layer, value_layer = _SplitLastDim.apply(mixed_kv_layer, 2)
else:
key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
if self.return_layernorm_output:
query_layer, layernorm_output = layernorm_query_outputs
else:
query_layer = layernorm_query_outputs
else:
query_layer = self.query_layer(
hidden_states,
is_first_microbatch=is_first_microbatch,
)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if inference_params and self.layer_number is not None:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...
]
# ==================================
# core attention computation
# ==================================
context_layer = self.core_attention(
query_layer,
key_layer,
value_layer,
attention_mask,
checkpoint_core_attention=checkpoint_core_attention,
)
# =================
# Output. [sq, b, h]
# =================
attention_output, attention_bias = self.proj(
context_layer, is_first_microbatch=is_first_microbatch
)
if self.input_layernorm and self.return_layernorm_output:
return attention_output, attention_bias, layernorm_output
return attention_output, attention_bias
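# Illustrative sketch (assumption, not a prescribed usage): constructing the
# module above on its own. `get_default_init_method` comes from
# transformer_engine.pytorch.utils; all sizes are arbitrary example values.
import torch

init = get_default_init_method()
mha = MultiHeadAttention(
    hidden_size=1024, num_attention_heads=16, kv_channels=64,
    attention_dropout=0.1, layernorm_epsilon=1e-5,
    init_method=init, output_layer_init_method=init,
    layer_number=1, input_layernorm=True,
)
x = torch.randn(128, 2, 1024, device="cuda")     # [sq, b, hidden_size]
out, proj_bias = mha(x)                          # output and unapplied proj bias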
class TransformerLayer(torch.nn.Module):
......