# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
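# Each LOGICAL_RULES entry pairs a tuple of base logical-axis rules with a flag
# marking whether extend_logical_axis_rules is expected to raise an
# AssertionError (e.g. for nested mesh-axis tuples, or for rules that collide
# with TE's reserved logical axes such as "batch").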
LOGICAL_RULES = [
[(("a1", None), ("a2", "ma2")), False],
[(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True],
[(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False],
[(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True],
[(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True],
]
MeshS = [
MeshResource(),
MeshResource("data", None),
MeshResource(None, "model"),
MeshResource("data", "model"),
]
class TestShardingSideAPI:
@pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES)
@pytest.mark.parametrize("sr", MeshS)
def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
with global_shard_guard(sr):
try:
target_te_rules = extend_logical_axis_rules(tuple())
extended_rules = extend_logical_axis_rules(base_rules)
assert extended_rules == (*base_rules, *target_te_rules)
assert not need_assert
except AssertionError as ae:
assert need_assert, f"{ae.args}"
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the softmax primitives"""
from contextlib import nullcontext
from dataclasses import dataclass
from functools import wraps
import jax
import jax.numpy as jnp
import pytest
from jax import lax
from jax import nn
from jax import value_and_grad, jit
from jax.typing import DTypeLike
from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax
def catch_unsupported(method):
"""
    An unsupported case should raise an error instead of running incorrectly.
    This helper checks that unsupported cases raise an AssertionError.
"""
@wraps(method)
def wrapper(self, *args, **kwargs):
if not self._is_support():
assertion_checker = pytest.raises(AssertionError)
else:
assertion_checker = nullcontext()
with assertion_checker:
return method(self, *args, **kwargs)
return wrapper
@dataclass
class SoftmaxRunner:
"""
Softmax runner
"""
batch_size: int
max_seqlen_q: int
max_seqlen_kv: int
num_heads: int
scale_factor: float
softmax_type: SoftmaxType
dtype: DTypeLike
@staticmethod
def reference_softmax(logits, mask, scale_factor, **_):
"""
        JAX softmax as the reference
"""
if mask is not None:
logits += lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return nn.softmax(logits * scale_factor)
def _is_support(self):
return is_softmax_kernel_available(
self.softmax_type,
self.batch_size,
self.num_heads,
self.max_seqlen_q,
self.max_seqlen_kv,
self.dtype,
)
def _setup_inputs(self):
key = jax.random.PRNGKey(0)
logits_key, mask_key = jax.random.split(key, 2)
logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv)
mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
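        # Mask convention for this test: 1 marks a masked-out position
        # (reference_softmax adds -1e10 to the logits wherever mask > 0).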
match self.softmax_type:
case SoftmaxType.SCALED:
self.mask = None
case SoftmaxType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _:
raise ValueError(f"Unknown {self.softmax_type=}")
@catch_unsupported
def test_forward(self):
"""
Test transformer_engine.jax.softmax.softmax fwd rule
"""
self._setup_inputs()
primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type)
reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
@catch_unsupported
def test_backward(self):
"""
Test transformer_engine.jax.softmax.softmax bwd rule
"""
self._setup_inputs()
def grad_func(func, *args, **kwargs):
fwd_out = func(*args, **kwargs)
return jnp.mean(fwd_out, dtype=jnp.float32).astype(self.dtype)
args = [self.logits, self.mask]
kwargs = {
"scale_factor": self.scale_factor,
"softmax_type": self.softmax_type,
}
        # Summing in FP16/BF16 may overflow, so use FP32 for the accumulation
jitted_primitive = jit(
value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs), (0,))
)
jitted_reference = jit(
value_and_grad(
lambda logits, *args: grad_func(
__class__.reference_softmax, logits, *args, **kwargs
),
(0,),
)
)
primitive_out, (primitive_grad_logits,) = jitted_primitive(*args)
reference_out, (reference_grad_logits,) = jitted_reference(*args)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)
@pytest.mark.parametrize(
"b, s_q, s_kv, h",
[
pytest.param(8, 16, 16, 16, id="8-16-16-16"),
pytest.param(8, 512, 512, 16, id="8-512-512-16"),
pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
],
)
class TestSoftmax:
"""
Test transformer_engine.jax.softmax.softmax
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
"""
Test forward with parameterized configs
"""
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_forward()
@staticmethod
def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
"""
        Test backward with parameterized configs
"""
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_backward()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility for the TE layer tests"""
import functools
import math
import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
import os
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks
from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
from transformer_engine.jax.attention import (
AttnMaskType,
canonicalize_attn_mask_type,
make_swa_mask,
)
from transformer_engine.jax.fp8 import DType as TEDType
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
# Enables verbose printing of tensor numerics for debug.
NVTE_DEBUG_NUMERICS = bool(int(os.getenv("NVTE_DEBUG_NUMERICS", 0)))
def is_devices_enough(required):
"""
    Check if enough GPUs are available
"""
return len(jax.devices()) >= required
def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]:
# Generate broadcast dims for drop_path.
drop_path_shape = list(range(0, len(shape)))
drop_path_shape.pop(batch_dim)
return drop_path_shape
def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
# A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
def _canonicalize_tuple(x):
if isinstance(x, Iterable):
return tuple(x)
return (x,)
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
"""Convert a string to an activation function."""
if fn_or_string == "linear":
return lambda x: x
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
def combine_biases(*masks: Optional[Array]):
"""Combine attention biases.
Args:
*masks: set of attention bias arguments to combine, some can be None.
Returns:
Combined mask, reduced by summation, returns None if no masks given.
"""
masks = [m for m in masks if m is not None]
if not masks:
return None
assert all(
map(lambda x: x.ndim == masks[0].ndim, masks)
), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
mask, *other_masks = masks
for other_mask in other_masks:
mask = mask + other_mask
return mask
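# Illustrative usage of combine_biases: biases of matching rank are summed
# elementwise and None entries are dropped, so e.g.
#   combine_biases(relpos_bias, None)            # -> relpos_bias
#   combine_biases(relpos_bias, attention_bias)  # -> elementwise sum
# With no biases at all it returns None.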
class DotProductAttention(nn.Module):
    """Computes dot-product attention given query, key, and value.

    This is the core function for applying attention based on
    https://arxiv.org/abs/1706.03762. It calculates the attention weights given
    query and key and combines the values using the attention weights.

    Args:
        dropout_rate: dropout rate
        dtype: the data type used to allocate the initial parameters (default: float32).
        float32_logits: bool, if True then compute logits in float32 to avoid
            numerical issues with bfloat16.
    """

    transpose_batch_sequence: bool = True
    scale_attn_logits: bool = True
    dropout_rate: float = 0.0
    dtype: DType = jnp.float32
    float32_logits: bool = False
@nn.compact
def __call__(
self,
query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
deterministic: bool = False,
):
"""
Args:
query: queries for calculating attention with shape of `[batch, q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch, kv_length,
num_gqa_groups, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch, kv_length,
num_gqa_groups, v_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch, num_heads, q_length, kv_length]` This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
            deterministic: bool, deterministic or not (to apply dropout). The
                dropout PRNG is drawn internally via self.make_rng("dropout").
Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
"""
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
batch_dim = 1 if self.transpose_batch_sequence else 0
assert (
query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
), "q, k, v batch dims must match."
sequence_dim = 0 if self.transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
if self.scale_attn_logits:
head_dim = query.shape[-1]
depth_scaling = jnp.sqrt(head_dim).astype(self.dtype)
query = query / depth_scaling
        # Cast logits to float32 so the softmax is computed in float32 for stability.
if self.float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
# `attn_weights`: [batch, num_heads, groups, q_length, kv_length]
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if self.transpose_batch_sequence:
attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
else:
attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
# Apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)
# Apply attention dropout.
if not deterministic and self.dropout_rate > 0.0:
keep_prob = 1.0 - self.dropout_rate
# T5 broadcasts along the "length" dim, but unclear which one that
# corresponds to in positional dimensions here, assuming query dim.
dropout_shape = list(attn_weights.shape)
dropout_rng = self.make_rng("dropout")
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
attn_weights = attn_weights.astype(value.dtype)
# Take the linear combination of `value`.
if self.transpose_batch_sequence:
return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
class DenseGeneral(nn.Module):
"""A linear transformation with flexible axes and FP8 support.
Attributes:
features: tuple with numbers of output features.
axis: tuple with axes to apply the transformation on.
dtype: the data type used to allocate the initial parameters (default: float32).
kernel_init: initializer function for the weight matrix.
use_bias: whether to add a bias to the output (default: False).
bias_init: initializer function for the bias vector.
"""
features: Union[Iterable[int], int]
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
kernel_init: Initializer = None
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
super().__post_init__()
@nn.compact
def __call__(self, inputs: Array) -> Array:
"""Applies a linear transformation to the inputs along multiple dimensions.
Args:
inputs: The nd-array to be transformed.
Returns:
The transformed input.
"""
input_dtype = inputs.dtype
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
inputs = jnp.asarray(inputs, self.dtype)
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
)
kernel = jnp.asarray(kernel, input_dtype)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
)
bias = bias.astype(input_dtype)
else:
bias = None
contract_ind = tuple(range(0, len(axis)))
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
y = y.astype(input_dtype)
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
return y
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block.
Attributes:
intermediate_dim: Shared dimension of hidden layers.
activations: Type of activations for each layer. Each element is either
'linear', a string function name in flax.linen, or a function.
kernel_init: Kernel function, passed to the dense layers.
deterministic: Whether the dropout layers should be deterministic.
intermediate_dropout_rate: Dropout rate used after the intermediate layers.
dtype: the data type used to allocate the initial parameters (default: float32).
"""
transpose_batch_sequence: bool
intermediate_dim: int = 2048
activations: Sequence[Union[str, Callable]] = ("relu",)
kernel_init: Initializer = None
intermediate_dropout_rate: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
use_bias: bool = False
dtype: Any = jnp.float32
fuse_wi: bool = True
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
super().__post_init__()
@nn.compact
def __call__(self, inputs, deterministic: bool = False):
"""Applies Transformer MlpBlock module."""
# Iterate over specified MLP input activation functions.
# e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
activations = []
if self.fuse_wi:
dense_name = "wi"
num_activations = len(self.activations)
x = DenseGeneral(
self.intermediate_dim * num_activations,
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=("embed", "mlp"),
use_bias=self.use_bias,
bias_axes="mlp",
name=dense_name,
)(inputs)
x = jnp.split(x, num_activations, axis=-1)
for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
else:
for idx, act_fn in enumerate(self.activations):
dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
x = DenseGeneral(
self.intermediate_dim,
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=("embed", "mlp"),
use_bias=self.use_bias,
bias_axes="mlp",
name=dense_name,
)(inputs)
x = _convert_to_activation_function(act_fn)(x)
activations.append(x)
# Take elementwise product of above intermediate activations.
x = functools.reduce(operator.mul, activations)
# Apply dropout and final dense output projection.
x = nn.Dropout(
rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims
)(
x, deterministic=deterministic
) # Broadcast along length.
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
else:
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
output = DenseGeneral(
inputs.shape[-1],
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=("mlp", "embed"),
use_bias=self.use_bias,
bias_axes="embed",
name="wo",
)(x)
        assert (
            output.dtype == inputs.dtype
        ), f"inputs.dtype={inputs.dtype}, output.dtype={output.dtype}"
return output
def apply_rotary_pos_emb_alternate(
inputs: jnp.ndarray,
position: jnp.ndarray,
min_timescale: int = 1,
max_timescale: int = 10000,
):
embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
timescale = min_timescale * (max_timescale / min_timescale) ** fraction
timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
sinusoid_inp = position / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
first_part = first_part.astype(inputs.dtype)
second_part = second_part.astype(inputs.dtype)
return jnp.concatenate([first_part, second_part], axis=-1)
def apply_rotary_pos_emb_consecutive(
inputs: jnp.ndarray,
position: jnp.ndarray,
min_timescale: int = 1,
max_timescale: int = 10000,
):
embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
inputs_shifted_right = jnp.concatenate([inputs[..., -1:], inputs[..., :-1]], axis=-1)
inputs_shifted = jax.lax.select(
jnp.tile(
jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2),
inputs.shape[:-1] + (1,),
),
inputs_shifted_right,
inputs_shifted_left,
)
fraction = jnp.repeat(fraction, 2)
timescale = min_timescale * (max_timescale / min_timescale) ** fraction
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
sinusoid_inp = position / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
outputs = inputs * cos + inputs_shifted * sin * sign
return outputs
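# Minimal usage sketch for the two RoPE variants above (illustrative; the
# shapes are assumptions): inputs are [batch, seq, heads, head_dim] with an
# even head_dim, and position holds absolute token indices per batch row.
#   x = jnp.ones((2, 8, 4, 64))
#   pos = jnp.broadcast_to(jnp.arange(8), (2, 8))
#   y_alt = apply_rotary_pos_emb_alternate(x, pos)    # rotates the two halves
#   y_con = apply_rotary_pos_emb_consecutive(x, pos)  # rotates adjacent pairs
# Both return an array of the same shape as x.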
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
class MultiHeadAttention(nn.Module):
"""Multi-head dot-product attention.
Attributes:
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
num_gqa_groups: number of kv attention heads
head_dim: dimension of each head.
dtype: the data type used to allocate the initial parameters (default: float32).
dropout_rate: dropout rate
kernel_init: initializer for the kernel of the Dense layers.
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
"""
num_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
transpose_batch_sequence: bool = True
dtype: DType = jnp.float32
dropout_rate: float = 0.0
kernel_init: Initializer = None
float32_logits: bool = False # computes logits in float32 for stability.
scale_attn_logits: bool = False
scaled_query_init: bool = True
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv: bool = True
use_bias: bool = False
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.dtype
)
if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_heads
super().__post_init__()
@nn.compact
def __call__(
self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
decode: bool = False,
deterministic: bool = False,
) -> Array:
"""Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and project the results to an output vector.
There are two modes: decoding and non-decoding (e.g., training). The mode is
determined by `decode` argument. For decoding, this method is called twice,
first to initialize the cache and then for an actual decoding process. The
two calls are differentiated by the presence of 'cached_key' in the variable
dict. In the cache initialization stage, the cache variables are initialized
as zeros and will be filled in the subsequent decoding process.
In the cache initialization call, `inputs_q` has a shape [batch, length,
q_features] and `inputs_kv`: [batch, length, kv_features]. During the
incremental decoding stage, query, key and value all have the shape [batch,
1, qkv_features] corresponding to a single step.
Args:
inputs_q: input queries of shape `[batch, q_length, q_features]`.
inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
decode: Whether to prepare and use an autoregressive cache.
deterministic: Disables dropout if set to True.
Returns:
output of shape `[batch, length, q_features]`.
"""
q_projection = functools.partial(
DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
kernel_axes=("embed", "joined_kv"),
use_bias=self.use_bias,
bias_axes="joined_kv",
dtype=self.dtype,
)
kv_projection = functools.partial(
DenseGeneral,
axis=-1,
features=self.num_gqa_groups * self.head_dim,
kernel_axes=("embed", "joined_kv"),
use_bias=self.use_bias,
bias_axes="joined_kv",
dtype=self.dtype,
)
# NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(depth_kq)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
query_init = lambda *args: self.kernel_init(*args) / (
depth_scaling if self.scaled_query_init else 1.0
)
# Project inputs_q to multi-headed q/k/v
# dimensions are then [batch, length, num_heads, head_dim]
def qkv_init(key, shape, dtype):
assert shape[-1] % 3 == 0
q_shape = (shape[0], shape[1] // 3)
k_shape = (shape[0], shape[1] // 3)
v_shape = (shape[0], shape[1] // 3)
q_kernel = query_init(key, q_shape, dtype)
k_kernel = self.kernel_init(key, k_shape, dtype)
v_kernel = self.kernel_init(key, v_shape, dtype)
return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)
is_self_attn = inputs_q is inputs_kv
is_gqa = self.num_heads != self.num_gqa_groups
is_qkvpack = is_self_attn and not is_gqa
if self.fuse_qkv:
if is_qkvpack:
qkv_proj = DenseGeneral(
axis=-1,
features=self.num_heads * self.head_dim * 3,
kernel_axes=("embed", "joined_kv"),
kernel_init=qkv_init,
use_bias=self.use_bias,
bias_axes="joined_kv",
name="qkv",
dtype=self.dtype,
)(inputs_kv)
query, key, value = jnp.split(
qkv_proj,
[self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
axis=-1,
)
else:
query = q_projection(kernel_init=query_init, name="query")(inputs_q)
kv_proj = DenseGeneral(
axis=-1,
features=self.num_gqa_groups * self.head_dim * 2,
kernel_axes=("embed", "joined_kv"),
kernel_init=self.kernel_init,
use_bias=self.use_bias,
bias_axes="joined_kv",
name="kv",
dtype=self.dtype,
)(inputs_kv)
key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
else:
query = q_projection(kernel_init=query_init, name="query")(inputs_q)
key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
if self.enable_rotary_pos_emb:
batch_dim = 1 if self.transpose_batch_sequence else 0
seq_dim = 1 - batch_dim
q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
            k_position = jnp.expand_dims(jnp.arange(key.shape[seq_dim]), axis=batch_dim)
if self.rotary_pos_emb_group_method == "alternate":
apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
else:
apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
query = apply_rotary_pos_emb(query, q_position)
key = apply_rotary_pos_emb(key, k_position)
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
if self.transpose_batch_sequence:
query = nn_partitioning.with_sharding_constraint(
query, ("length", "batch", "heads", "kv")
)
key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
value = nn_partitioning.with_sharding_constraint(
value, ("length", "batch", "heads", "kv")
)
else:
query = nn_partitioning.with_sharding_constraint(
query, ("batch", "length", "heads", "kv")
)
key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
value = nn_partitioning.with_sharding_constraint(
value, ("batch", "length", "heads", "kv")
)
if decode:
# Detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
# The key and value have dimension [batch, length, num_heads, head_dim],
# but we cache them as [batch, num_heads, head_dim, length] as a TPU
# fusion optimization. This also enables the "scatter via one-hot
# broadcast" trick, which means we do a one-hot broadcast instead of a
# scatter/gather operations, resulting in a 3-4x speedup in practice.
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable(
"cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype
)
cached_value = self.variable(
"cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype
)
cache_index = self.variable(
"cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
)
if is_initialized:
batch, num_heads, head_dim, length = cached_key.value.shape
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
# Sanity shape check of cached key against input query.
expected_shape = (batch, 1, num_heads, head_dim)
if expected_shape != query.shape:
raise ValueError(
"Autoregressive cache shape error, "
f"expected query shape {expected_shape} instead got {query.shape}."
)
                # Create a one-hot encoding of the current index. NOTE: the index is increased below.
cur_index = cache_index.value
one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
# In order to update the key, value caches with the current key and
# value, we move the length axis to the back, similar to what we did for
# the cached ones above.
# Note these are currently the key and value of a single position, since
# we feed one position at a time.
one_token_key = jnp.moveaxis(key, -3, -1)
one_token_value = jnp.moveaxis(value, -3, -1)
# Update key, value caches with our new 1d spatial slices.
# We implement an efficient scatter into the cache via one-hot
# broadcast and addition.
key = cached_key.value + one_token_key * one_hot_indices
value = cached_value.value + one_token_value * one_hot_indices
cached_key.value = key
cached_value.value = value
cache_index.value = cache_index.value + 1
# Move the keys and values back to their original shapes.
key = jnp.moveaxis(key, -1, -3)
value = jnp.moveaxis(value, -1, -3)
# Causal mask for cached decoder self-attention: our single query
# position should only attend to those key positions that have already
# been generated and cached, not the remaining zero elements.
mask = combine_masks(
jnp.logical_not(mask),
jnp.broadcast_to(
jnp.arange(length) <= cur_index,
# (1, 1, length) represent (head dim, query length, key length)
# query length is 1 because during decoding we deal with one
# index.
# The same mask is applied to all batch elements and heads.
(batch, 1, 1, length),
),
)
# Grab the correct relative attention bias during decoding. This is
# only required during single step decoding.
if bias is not None:
# The bias is a full attention matrix, but during decoding we only
# have to take a slice of it.
# This is equivalent to bias[..., cur_index:cur_index+1, :].
bias = dynamic_vector_slice_in_dim(
jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
)
# Convert the boolean attention mask to an attention bias.
if mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
mask > 0,
jnp.full(mask.shape, 0.0).astype(self.dtype),
jnp.full(mask.shape, -1e10).astype(self.dtype),
)
else:
attention_bias = None
# Add provided bias term (e.g. relative position embedding).
if bias is not None:
attention_bias = combine_biases(attention_bias, bias)
# Apply attention.
x = DotProductAttention(
transpose_batch_sequence=self.transpose_batch_sequence,
scale_attn_logits=self.scale_attn_logits,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
float32_logits=self.float32_logits,
)(query, key, value, bias=attention_bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
else:
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
# Back to the original inputs dimensions.
out = DenseGeneral(
features=inputs_q.shape[-1], # output dim is set to the input dim.
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=("joined_kv", "embed"),
use_bias=self.use_bias,
bias_axes="embed",
dtype=self.dtype,
name="out",
)(x)
assert (
inputs_q.dtype == inputs_kv.dtype == out.dtype
), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
return out
class LayerNorm(nn.Module):
"""T5 Layer normalization operating on the last axis of the input data."""
epsilon: float = 1e-6
dtype: Any = jnp.float32
layernorm_type: str = "layernorm"
zero_centered_gamma: bool = False
scale_init: Initializer = None
bias_init: Initializer = nn.initializers.zeros
def __post_init__(self):
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
super().__post_init__()
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Applies layer normalization on the input."""
input_dtype = x.dtype
features = x.shape[-1]
scale = nn_partitioning.param_with_axes(
"scale", self.scale_init, (features,), self.dtype, axes=("embed",)
)
scale = jnp.asarray(scale, input_dtype)
if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
y = (x - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes(
"ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
)
bias = jnp.asarray(bias, input_dtype)
if not self.zero_centered_gamma:
z = y * scale + bias
else:
z = y * (scale + 1.0) + bias
else:
assert self.layernorm_type == "rmsnorm"
assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = x * lax.rsqrt(mean2 + self.epsilon)
z = y * scale
assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
return z
class RelativePositionBiases(nn.Module):
"""Adds T5-style relative positional embeddings to the attention logits.
Attributes:
num_buckets: Number of buckets to bucket distances between key and query
positions into.
max_distance: Maximum distance before everything is lumped into the last
distance bucket.
num_heads: Number of heads in the attention layer. Each head will get a
different relative position weighting.
dtype: the data type used to allocate the initial parameters (default: float32).
embedding_init: initializer for relative embedding table.
"""
num_buckets: int
max_distance: int
num_heads: int
dtype: Any
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
@staticmethod
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e.
the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are
invalid.
We use smaller buckets for small absolute relative_position and larger
buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative
positions <=-max_distance map to the same bucket. This should allow for
more graceful generalization to longer sequences than the model has been
trained on.
Args:
relative_position: an int32 array
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += (n < 0).astype(np.int32) * num_buckets
n = np.abs(n)
else:
n = np.maximum(n, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
/ np.log(max_distance / max_exact)
* (num_buckets - max_exact)
).astype(np.int32)
val_if_large = np.minimum(val_if_large, num_buckets - 1)
ret += np.where(is_small, n, val_if_large)
return ret
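    # Worked example (bidirectional=True, num_buckets=32, max_distance=128):
    # relative_position = memory_position - query_position, so
    #   _relative_position_bucket(np.array([-1, 0, 1])) -> array([1, 0, 17])
    # i.e. positive relative positions (memory after query) land in the upper
    # half of the buckets.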
@nn.compact
def __call__(self, qlen, klen, bidirectional=True):
"""Produce relative position embedding attention biases.
Args:
qlen: attention query length.
klen: attention key length.
bidirectional: whether to allow positive memory-query relative position
embeddings.
Returns:
            output: `(1, num_heads, q_len, k_len)` attention bias
"""
context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position,
bidirectional=bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
relative_attention_bias = nn_partitioning.param_with_axes(
"rel_embedding",
self.embedding_init,
(self.num_heads, self.num_buckets),
jnp.float32,
axes=("heads", "relpos_buckets"),
)
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
# Instead of using a slow gather, we create a leading-dimension one-hot
# array from rp_bucket and use it to perform the gather-equivalent via a
# contraction, i.e.:
# (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
# This is equivalent to relative_attention_bias[:, rp_bucket]
bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
        # --> shape (num_heads, qlen, klen)
values = lax.dot_general(
relative_attention_bias,
rp_bucket_one_hot,
            (((1,), (0,)), ((), ())),  # lhs, rhs contracting dims
) # no batched dims
# Add a singleton batch dimension.
# --> shape (1, num_heads, qlen, klen)
return values[jnp.newaxis, ...]
def apply_swa_mask(
attn_mask_type: str,
original_mask: Array,
window_size: Tuple[int, int] = (-1, -1),
) -> Array:
"""Apply the sliding window mask to a given mask"""
_attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
assert _attn_mask_type is not None
batch = original_mask.shape[0]
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype)
# In swa_mask and original_mask 0 is masked out
new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
return new_mask
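# Usage sketch for apply_swa_mask (illustrative; here the mask follows the
# 1 = attend, 0 = masked-out convention with shape [batch, 1, seqlen_q, seqlen_kv]):
#   mask = jnp.ones((2, 1, 128, 128), dtype=jnp.uint8)
#   swa = apply_swa_mask("causal", mask, window_size=(16, 0))
# keeps only positions within 16 tokens to the left of each query attendable.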
class EncoderLayer(nn.Module):
"""Transformer encoder layer."""
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
transpose_batch_sequence: bool = True
float32_attention_logits: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
output_layernorm: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False):
        # The cuDNN backend currently supports SWA only for causal/padding_causal
        # mask types, so follow that convention here.
encoder_mask = apply_swa_mask(
self.self_attn_mask_type,
encoder_mask,
self.window_size,
)
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
if self.enable_relative_embedding:
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(
num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
name="relpos_bias",
)
else:
rel_emb = self.relative_embedding
encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
else:
encoder_bias = None
# Attention block.
residual = inputs
if not self.output_layernorm:
# Attention block.
x = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_attention_layer_norm",
)(inputs)
if self.apply_residual_connection_post_layernorm:
residual = x
else:
x = inputs
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
use_bias=self.use_bias,
name="attention",
)(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
x, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
x, deterministic=deterministic
)
x = x + residual
# MLP block.
residual = x
y = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_mlp_layer_norm",
)(x)
if self.apply_residual_connection_post_layernorm:
residual = y
# [batch, length, emb_dim] -> [batch, length, emb_dim]
y = MlpBlock(
transpose_batch_sequence=self.transpose_batch_sequence,
intermediate_dim=self.mlp_dim,
activations=self.mlp_activations,
intermediate_dropout_rate=self.intermediate_dropout,
intermediate_dropout_dims=self.intermediate_dropout_dims,
use_bias=self.use_bias,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
name="mlp",
)(y, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
y, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
y, deterministic=deterministic
)
y = y + residual
if self.output_layernorm:
y = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layernorm",
)(y)
assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
return y
class DecoderLayer(nn.Module):
"""Transformer decoder layer that attends to the encoder."""
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
transpose_batch_sequence: bool = True
float32_attention_logits: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
self_attn_mask_type: str = "no_mask"
window_size: Tuple[int, int] = (-1, -1)
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
def __call__(
self,
inputs,
encoded,
decoder_mask=None,
encoder_decoder_mask=None,
deterministic=False,
decode=False,
max_decode_length=None,
):
decoder_mask = apply_swa_mask(
self.self_attn_mask_type,
decoder_mask,
self.window_size,
)
encoder_decoder_mask = apply_swa_mask(
"padding",
encoder_decoder_mask,
self.window_size,
)
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
if self.enable_relative_embedding:
l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(
num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
name="relpos_bias",
)
else:
rel_emb = self.relative_embedding
decoder_bias = rel_emb(l, l, False)
else:
decoder_bias = None
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
residual = inputs
if not self.output_layernorm:
# Attention block.
x = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_self_attention_layer_norm",
)(inputs)
if self.apply_residual_connection_post_layernorm:
residual = x
else:
x = inputs
# Self-attention block
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
name="self_attention",
)(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
x, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
x, deterministic=deterministic
)
x = x + residual
# Encoder-Decoder block.
residual = x
y = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_cross_attention_layer_norm",
)(x)
if self.apply_residual_connection_post_layernorm:
residual = y
y = MultiHeadAttention(
num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
name="encoder_decoder_attention",
)(y, encoded, encoder_decoder_mask, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
y, deterministic=deterministic
)
y = y + residual
# MLP block.
residual = y
z = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_mlp_layer_norm",
)(y)
if self.apply_residual_connection_post_layernorm:
residual = z
z = MlpBlock(
transpose_batch_sequence=self.transpose_batch_sequence,
intermediate_dim=self.mlp_dim,
activations=self.mlp_activations,
intermediate_dropout_rate=self.intermediate_dropout,
intermediate_dropout_dims=self.intermediate_dropout_dims,
use_bias=self.use_bias,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
name="mlp",
)(z, deterministic=deterministic)
z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
z, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
z, deterministic=deterministic
)
z = z + residual
if self.output_layernorm:
z = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layernorm",
)(z)
assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}"
return z
def make_causal_mask(batch, seqlen, dtype=jnp.uint8):
"""
Generate causal mask
"""
shape = (batch, seqlen)
idxs = jnp.broadcast_to(jnp.arange(shape[-1], dtype=jnp.int32), shape)
mask = jnp.greater_equal(jnp.expand_dims(idxs, axis=-1), jnp.expand_dims(idxs, axis=-2))
mask = jnp.expand_dims(mask, axis=-3)
mask = 1 - mask
return mask.astype(dtype)
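# Worked example: for seqlen=3 each [seqlen, seqlen] slice of the result is
#   [[0, 1, 1],
#    [0, 0, 1],
#    [0, 0, 0]]
# i.e. 1 marks the future positions that are masked out.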
def make_self_mask(batch, seqlen, dtype=jnp.uint8):
"""
Generate attention mask
"""
shape = (batch, seqlen)
mask = jnp.ones((*shape, shape[-1]))
mask = jnp.expand_dims(mask, axis=-3)
mask = 1 - mask
return mask.astype(dtype)
def assert_allclose(
actual: Array,
desired: Array,
rtol: Optional[float] = None,
atol: Optional[float] = None,
dtype: Optional[Union[DType, TEDType, np.dtype, str]] = None,
**kwargs,
) -> None:
"""Check if two tensors are close.
Args:
actual: test tensor.
desired: reference tensor.
dtype: data type or data type name (default: inferred from
`actual`).
rtol: relative tolerance (default: based on `dtype`).
atol: absolute tolerance (default: based on `dtype`).
**kwargs: keyword arguments to pass to np.testing.assert_allclose.
"""
# Infer data type if needed
if dtype is None:
if isinstance(actual, float):
dtype = "float32"
else:
dtype = actual.dtype
# Determine tolerances
tols = {}
if rtol is None or atol is None:
tols = dtype_tols(dtype)
if rtol is not None:
tols["rtol"] = rtol
if atol is not None:
tols["atol"] = atol
# Cast tensors to fp32
if not isinstance(actual, float):
actual = actual.astype(jnp.float32)
if not isinstance(desired, float):
desired = desired.astype(jnp.float32)
# Check if tensors are close
np.testing.assert_allclose(actual, desired, **tols, **kwargs)
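# Illustrative usage: tolerances may be derived from a dtype instead of passed
# explicitly, e.g.
#   assert_allclose(test_out, ref_out, dtype=jnp.bfloat16)
# casts both tensors to FP32 and compares with bfloat16-derived rtol/atol
# (see dtype_tols below).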
def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)
for (expected_path, expected_value), (actual_path, actual_value) in zip(
flatten_expected, flatten_actual
):
assert expected_path == actual_path
key_str = jax.tree_util.keystr(expected_path)
assert_allclose(
expected_value,
actual_value,
rtol=rtol,
atol=atol,
err_msg=f"Value of expected{key_str} and actual{key_str} is not close",
)
def dtype_tols(
dtype: Union[DType, TEDType, np.dtype],
reference_value: float = 1.0,
rtol: Optional[float] = None,
atol: Optional[float] = None,
) -> Dict[str, float]:
"""Expected numerical tolerance for a data type.
Args:
dtype: data type.
reference_value: reference value (default: 1).
rtol: override for relative tolerance estimate
atol: override for absolute tolerance estimate
Returns:
Dictionary with "rtol" and "atol" as keys
"""
# Return immediately if tolerances are fully specified
if rtol is not None and atol is not None:
return {"rtol": rtol, "atol": atol}
# Convert to JAX dtype if needed
if isinstance(dtype, TEDType):
dtype = {
TEDType.kByte: jnp.uint8,
TEDType.kInt32: jnp.int32,
TEDType.kInt64: jnp.int64,
TEDType.kFloat32: jnp.float32,
TEDType.kFloat16: jnp.float16,
TEDType.kBFloat16: jnp.bfloat16,
TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
TEDType.kFloat8E5M2: jnp.float8_e5m2,
}[dtype]
elif isinstance(dtype, np.dtype):
dtype = jnp.dtype(dtype)
# Expect bit-wise accuracy for integer dtypes
if not jnp.issubdtype(dtype, jnp.floating):
if rtol is None:
rtol = 0.0
if atol is None:
atol = 0.0
return {"rtol": rtol, "atol": atol}
# Estimate floating-point error
finfo = jnp.finfo(dtype)
eps_relaxed = math.pow(finfo.eps, 2 / 3)
with jax.default_device(jax.devices("cpu")[0]):
if isinstance(reference_value, (float, int)):
reference_value = jnp.array(reference_value, dtype=dtype)
else:
reference_value = reference_value.astype(dtype)
spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value
spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min)
ulp = max(spacing_high.item(), spacing_low.item())
if rtol is None:
rtol = eps_relaxed
if atol is None:
atol = max(ulp, eps_relaxed)
return {"rtol": rtol, "atol": atol}
def sync_params_values(dst, src, transformations, sep="/"):
"""
    This function reconstructs a tree with dst's tree_def/shape and src's values.
    transformations is a map that records the key mappings between dst and src.
    If a dst key is not found in transformations, it falls back to src key = dst key.
transformations = {
dst key map 0: src key map 0,
dst key map 1: src key map 1,
...
# if dst key = src key, we don't need to add it
}
"""
src_values = {}
for key, value in jax.tree_util.tree_leaves_with_path(src):
normalized_key = sep.join(x.key for x in key)
src_values[normalized_key] = value
flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
synced_dst_values = []
for key, value in flatten_dst:
normalized_key = sep.join(x.key for x in key)
if normalized_key in transformations:
corresponding_src_key = transformations[normalized_key]
else:
corresponding_src_key = normalized_key
synced_dst_values.append(src_values[corresponding_src_key])
synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values)
return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst)
@functools.partial(jax.jit, static_argnums=[0, 2])
def print_debug_tensor_stats(prefix, tensor, hist=False):
if NVTE_DEBUG_NUMERICS:
args = [
jnp.mean(tensor),
jnp.min(tensor),
jnp.max(tensor),
jnp.cumprod(jnp.array(tensor.shape))[-1] if len(tensor.shape) >= 1 else 1,
jnp.count_nonzero(tensor),
]
fmt = prefix + " mean={}, min={}, max={}, numel={}, nzcnt={}"
if hist:
h = jnp.histogram(tensor.astype(jnp.float32), bins=10)
args += [h[0], h[1]]
fmt = fmt + "\n {}\n {}"
jax.debug.print(fmt, *args)
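# Usage sketch: output appears only when NVTE_DEBUG_NUMERICS=1 is set in the
# environment before this module is imported, e.g.
#   print_debug_tensor_stats("attn_out", x)             # mean/min/max/numel/nzcnt
#   print_debug_tensor_stats("attn_out", x, hist=True)  # plus a 10-bin histogram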
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import argparse
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext
class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
self.fc1 = te.Linear(input_size, hidden_size)
self.fc2 = te.Linear(hidden_size, output_size)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs
def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
    parser.add_argument(
        "--iter", type=int, default=10, help="Number of training iterations"
    )
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    # Add the sharding dims as a list argument (space-separated, e.g. --sharding-dims 2 4)
parser.add_argument(
"--sharding-dims",
type=int,
nargs="+",
help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
)
args = parser.parse_args(argv, namespace)
if args.sharding_dims:
assert len(args.sharding_dims) <= 2
return args
sub_modules_to_wrap = [te.Linear]
def _train(args):
assert "TORCHELASTIC_RUN_ID" in os.environ
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert LOCAL_SIZE == WORLD_SIZE
# Set device and initialize RNG states
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# Initialize torch.distributed global process group and get DP/TP groups
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
device = torch.device(f"cuda:{LOCAL_RANK}")
# FP8 Configuration
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    # Build the model, optionally initializing primary weights directly in FP8
    build_model_context = nullcontext
    build_model_context_args = {}
    if args.fp8_init:
        from transformer_engine.pytorch import fp8_model_init

        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True
    with build_model_context(**build_model_context_args):
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
# Move the model to the correct device
model.to(device)
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
# Creating a DeviceMesh for fully_shard
world_size = int(WORLD_SIZE)
device_ids = list(range(world_size))
if LOCAL_RANK == 0:
print(f"sharding-dims:{args.sharding_dims}")
# Setup the sharding mesh for FSDP/HSDP
    if args.sharding_dims is None:  # FSDP
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 1:  # FSDP with explicit world size
        assert args.sharding_dims[0] == device_ids[-1] + 1
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 2:  # HSDP
        assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
        mesh = init_device_mesh(
            "cuda",
            (args.sharding_dims[0], args.sharding_dims[1]),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        raise ValueError(f"Expected at most 2 sharding dimensions, got {args.sharding_dims}")
# Apply FSDP/HSDP
custom_attrs = save_custom_attrs(model)
for sub_module in model.modules():
if any(
isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, mesh=mesh)
fully_shard(model, mesh=mesh)
restore_custom_attrs(model, custom_attrs)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for iteration in range(args.iter):
# Zero the parameter gradients
optimizer.zero_grad()
        input_data = torch.randn(args.batch_size, args.input_size, device=device)
        output = model(input_data)
        target = torch.randn(args.batch_size, args.output_size, device=device)
loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
dist.destroy_process_group()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Done...")
return 0
if __name__ == "__main__":
sys.exit(_train(_parse_args()))
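# Example launches (the script name is illustrative; the script requires a
# torchrun-style launcher since it asserts TORCHELASTIC_RUN_ID is set):
#
#   torchrun --nproc_per_node=8 fsdp_toy_example.py                      # FSDP over all 8 ranks
#   torchrun --nproc_per_node=8 fsdp_toy_example.py --sharding-dims 2 4  # HSDP: 2-way replicate x 4-way shard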
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import socket
import warnings
import subprocess
import argparse
import operator
from functools import partial, reduce
import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
torch_dtypes = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
nvte_comm_types = {
"rs": tex.CommOverlapType.RS,
"ag": tex.CommOverlapType.AG,
}
def _mapped_argtype(opt, typemap):
if str(opt).lower() not in typemap.keys():
raise TypeError(f"Unrecognized option! Please choose from: {typemap.keys()}")
return typemap[str(opt).lower()]
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.")
parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.")
parser.add_argument(
"-n", "--num-heads", type=int, default=16, help="Number of attention heads."
)
parser.add_argument(
"-d", "--head-dim", type=int, default=48, help="Dimension of each attention head."
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
)
parser.add_argument(
"--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."
)
parser.add_argument(
"--p2p", action="store_true", default=False, help="Test overlap with P2P comms."
)
parser.add_argument(
"--atomic", action="store_true", default=False, help="Test overlap with atomic GEMM."
)
parser.add_argument(
"--aggregate",
action="store_true",
default=False,
help="Aggregate 2X chunks for P2P split pipelined all-gather.",
)
parser.add_argument(
"--comm-type",
type=partial(_mapped_argtype, typemap=nvte_comm_types),
default=tex.CommOverlapType.AG,
help="Comm type to overlap.",
)
parser.add_argument(
"--bulk-overlap",
action="store_true",
default=False,
help="Enable bulk AG or RS overlap for a tensor that is not involved in the GEMM compute.",
)
parser.add_argument(
"--check-numerics",
action="store_true",
default=False,
help="Test numerical result against torch.matmul(...)",
)
parser.add_argument(
"--warmup-iters",
type=int,
default=0,
help="Run some warmup iterations of the comm+GEMM overlap before " + "the timing runs.",
)
parser.add_argument(
"--timing-iters",
type=int,
default=1,
help="Benchmark the comm+GEMM overlap as an average of many iterations.",
)
parser.add_argument(
"--clock-speed",
type=int,
default=-1,
help="Set device clock speed to a fixed value via `nvidia-smi`.",
)
parser.add_argument(
"--std", type=float, default=0.023, help="Standard deviation for input and weight tensors."
)
parser.add_argument(
"--tcp-init",
action="store_true",
default=False,
help="Initialize torch.distributed with TcpStore.",
)
parser.add_argument(
"--init-method", type=str, default=None, help="Set the torch.distributed init method."
)
parser.add_argument(
"--bind-to-device",
action="store_true",
default=False,
help=(
"Initialize torch.distributed with 'device_id' argument to bind each rank to 1 device."
),
)
parser.add_argument(
"--bootstrap-backend",
type=str.lower,
default="nccl",
choices=["gloo", "mpi", "nccl"],
help=(
"PyTorch distributed backend for host tensor collectives during comm+GEMM overlap "
+ "initialization."
),
)
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA graphs."
)
parser.add_argument(
"-v", "--verbose", action="store_true", default=False, help="Verbose info messages."
)
opts = parser.parse_args(argv, namespace)
if opts.bulk_overlap:
if opts.p2p:
warnings.warn("Point-2-point comms are not supported with bulk overlap.")
opts.p2p = False
if opts.atomic:
warnings.warn("Atomic GEMM is not supported with bulk overlap.")
opts.atomic = False
if opts.fp8:
warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
opts.fp8 = False
elif opts.comm_type == tex.CommOverlapType.AG:
if opts.atomic:
setattr(opts, "atomic_rs_p2p", opts.p2p)
opts.p2p = True
    if opts.atomic:
        assert te.fp8.check_fp8_support(), "Atomic GEMM is only supported in FP8."
        opts.fp8 = True
return opts
@record
def _main(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N`
WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
opts.tcp_init = True
opts.bootstrap_backend = "mpi"
else: # TORCHELASTIC, SLURM, etc...
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
result = subprocess.run(
"nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
capture_output=True,
text=True,
shell=True,
)
if result.stdout == "0": # Extra checks for non-MNNVL platforms
assert WORLD_SIZE == LOCAL_SIZE
assert LOCAL_SIZE <= torch.cuda.device_count()
# Fix clock speed
torch.cuda.set_device(LOCAL_RANK)
if opts.clock_speed > 0:
subprocess.run(
["nvidia-smi", "-pm", "ENABLED", "-i", str(LOCAL_RANK)],
env=os.environ,
check=False,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
result = subprocess.run(
["nvidia-smi", "-lgc", str(opts.clock_speed), "-i", str(LOCAL_RANK)],
env=os.environ,
check=False,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
msg = result.stdout.decode("utf-8").splitlines()[0]
print(f"[rank:{LOCAL_RANK}] {msg}\n", end="", flush=True)
# Info printout
def dist_print(msg, src=None, info=False, error=False, section=False, group=None):
group = dist.new_group() if group is None else group
rank = dist.get_rank(group)
stream = sys.stderr if error else sys.stdout
if info or opts.verbose:
if section:
if rank == (0 if src is None else src):
stream.write("\n")
dist.barrier(group)
if src is None or rank == src:
prefix = "[GLOBAL] " if src is not None else f"[rank:{rank}] "
lines = msg.splitlines()
msg = "\n".join(
[prefix + lines[0]] + [(" " * len(prefix)) + line for line in lines[1:]]
)
stream.write(msg + "\n")
dist.barrier(group)
# Initialize torch.distributed global process group and get TP group
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
if opts.tcp_init:
if opts.init_method is not None:
assert opts.init_method.startswith("tcp://")
init_method = opts.init_method
else:
MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
MASTER_PORT = os.getenv("MASTER_PORT", "1234")
init_method = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
dist_init_kwargs["init_method"] = init_method
elif opts.init_method is not None:
assert (
opts.init_method.startswith("env://")
or opts.init_method.startswith("file://")
or opts.init_method.startswith("tcp://")
)
dist_init_kwargs["init_method"] = opts.init_method
if opts.bind_to_device or opts.bootstrap_backend == "nccl":
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
tp_group = dist.new_group(backend="nccl")
tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
dist_print(
f"Initialized default NCCL process group with {tp_size} GPUs",
src=0,
section=True,
info=True,
group=tp_group,
)
# Initialize backend used in bootstrapping Userbuffers
if opts.bootstrap_backend == "gloo":
assert dist.is_gloo_available()
elif opts.bootstrap_backend == "mpi":
assert dist.is_mpi_available()
bootstrap_pg = dist.new_group(backend=opts.bootstrap_backend)
dist_print(
f'Bootstrapping comm+GEMM overlap with backend="{opts.bootstrap_backend}"',
src=0,
section=True,
info=True,
group=bootstrap_pg,
)
if WORLD_RANK == 0:
print("\n", end="", flush=True)
helper = (
tex.CommOverlapHelper()
if tex.ubuf_built_with_mpi()
else tex.CommOverlapHelper(bootstrap_pg)
)
# Initialize userbuffers with (M, N) buffer
# M = sequence * batch
# N = hidden size
hidden_size = opts.num_heads * opts.head_dim
inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
outer_size = reduce(operator.mul, inp_shape[:-1], 1)
buffer_dtype = torch.bfloat16
if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG:
buffer_dtype = torch.uint8
ub_obj = (
tex.CommOverlapP2P(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
opts.comm_type,
set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
atomic_gemm=opts.atomic,
aggregate=opts.aggregate,
use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
)
if opts.p2p
else tex.CommOverlap(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=opts.atomic,
)
)
# Numerical check on AG + atomic GEMM requires testing an AG+RS pair
ub_obj2 = None
if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics:
ub_obj2 = (
tex.CommOverlapP2P(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
tex.CommOverlapType.RS,
set_sm_margin=True,
atomic_gemm=True,
)
if opts.atomic_rs_p2p
else tex.CommOverlap(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=True,
)
)
# Figure out problem sizing:
# M = sequence * batch
# N = hidden size
# K = MLP intermediate size (usually 4x hidden size)
# P = number of devices for sequence/tensor parallelism
    # NOTE: TE-GEMM is set up to work with transposed kernels and non-transposed inputs.
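    # Worked example with the default arguments and, say, tp_size = 8:
    #   M = 1024 * 2 = 2048, N = 16 * 48 = 768, K = 4 * 768 = 3072, P = 8
    #   AG overlap: input (M/P, N) = (256, 768), kernel_t (K/P, N) = (384, 768)
    #               -> output (M, K/P) = (2048, 384)
    #   RS overlap: input (M, K/P) = (2048, 384), kernel_t (N, K/P) = (768, 384)
    #               -> output (M/P, N) = (256, 768)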
ffn_hidden_size = 4 * hidden_size
if opts.bulk_overlap:
        # For bulk overlap the weight and input only feed the GEMM (they are not
        # part of the communication), so they are globally sized
local_kernel_t_shape = (ffn_hidden_size, hidden_size)
local_inp_shape = (outer_size, hidden_size)
# Bulk overlap comm tensor is distributed for AG overlap only
if opts.comm_type == tex.CommOverlapType.AG:
bulk_inp_shape = (outer_size // tp_size, hidden_size)
else:
bulk_inp_shape = (outer_size, hidden_size)
else:
if opts.comm_type == tex.CommOverlapType.AG:
# (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P)
local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size)
local_inp_shape = (outer_size // tp_size, hidden_size)
if ub_obj2 is not None:
local_kernel2_t_shape = (hidden_size, ffn_hidden_size // tp_size)
else:
# (M, K/P) x (N, K/P)^T = (M, N) -> overlapped RS -> (M/P, N)
local_kernel_t_shape = (hidden_size, ffn_hidden_size // tp_size)
local_inp_shape = (outer_size, ffn_hidden_size // tp_size)
# Initialize distributed input tensor and GEMM kernels
torch.manual_seed(opts.seed + tp_rank)
torch.cuda.manual_seed(opts.seed + tp_rank)
inp = torch.nn.init.normal_(
torch.empty(local_inp_shape, dtype=torch.bfloat16, device="cuda"),
mean=0.0,
std=opts.std,
)
kernel_t = torch.nn.init.normal_(
torch.empty(local_kernel_t_shape, dtype=torch.bfloat16, device="cuda"),
mean=0.0,
std=opts.std,
)
if ub_obj2 is not None:
kernel2_t = torch.nn.init.normal_(
torch.empty(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"),
mean=0.0,
std=opts.std,
)
# Allocate cuBLAS workspace
workspace_size = 3 * get_cublas_workspace_size_bytes()
workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
# Gather global tensors and calculate reference result (need these first for Fp8 scales)
if opts.bulk_overlap:
ker_g = torch.transpose(kernel_t, 0, 1)
inp_g = inp
bulk_inp = torch.nn.init.normal_(
torch.empty(bulk_inp_shape, dtype=torch.bfloat16, device="cuda"),
mean=0.0,
std=opts.std,
)
else:
if opts.comm_type == tex.CommOverlapType.AG:
# AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K)
ker_g = torch.transpose(
te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1
).to(dtype=torch.float32)
# AG Input: (M/P, N) -> gather -> (M, N)
inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0].to(dtype=torch.float32)
if ub_obj2 is not None:
ker2_g = te.distributed.gather_along_first_dim(
torch.transpose(kernel2_t, 0, 1), tp_group
)[0].to(dtype=torch.float32)
else:
# RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N)
ker_g = te.distributed.gather_along_first_dim(
torch.transpose(kernel_t, 0, 1), tp_group
)[0].to(dtype=torch.float32)
# RS Input: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K)
inp_g = torch.transpose(
te.distributed.gather_along_first_dim(torch.transpose(inp, 0, 1), tp_group)[0], 0, 1
).to(dtype=torch.float32)
if opts.bulk_overlap:
if opts.comm_type == tex.CommOverlapType.AG:
ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0]
else:
# First all-gather all the bulk inputs into a list
bulk_inp_list = [torch.zeros_like(bulk_inp) for _ in range(tp_size)]
dist.all_gather(bulk_inp_list, bulk_inp, tp_group)
# Sum the list together for final global result
ref_g = torch.stack(bulk_inp_list).sum(dim=0)
else:
ref_g = torch.matmul(inp_g, ker_g)
if ub_obj2 is not None:
inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable
ref2_g = torch.matmul(inp2_g, ker2_g)
inp_quantizer = None
ker_quantizer = None
out_quantizer = None
bulk_inp_quantizer = None
inp2_quantizer = None
ker2_quantizer = None
out2_quantizer = None
if opts.fp8:
# Structure to maintain amax and scale/scale_inv information for the kernel and input
num_gemms = 6 if ub_obj2 is not None else 3
fp8_dtype = tex.DType.kFloat8E4M3
fp8_scales = torch.ones(num_gemms, dtype=torch.float, device="cuda")
fp8_amaxes = torch.zeros(num_gemms, dtype=torch.float, device="cuda")
# Compute initial amaxes and scales
inp_amax = torch.max(torch.abs(inp_g))
fp8_amaxes[0].copy_(inp_amax)
ker_amax = torch.max(torch.abs(ker_g))
fp8_amaxes[1].copy_(ker_amax)
ref_amax = torch.max(torch.abs(ref_g))
fp8_amaxes[2].copy_(ref_amax)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_amax = torch.max(torch.abs(bulk_inp))
fp8_amaxes[5].copy_(bulk_amax)
elif ub_obj2 is not None:
inp2_amax = torch.max(torch.abs(inp2_g))
fp8_amaxes[3].copy_(inp2_amax)
ker2_amax = torch.max(torch.abs(ker2_g))
fp8_amaxes[4].copy_(ker2_amax)
ref2_amax = torch.max(torch.abs(ref2_g))
fp8_amaxes[5].copy_(ref2_amax)
inp_quantizer = Float8Quantizer(fp8_scales[0].clone(), fp8_amaxes[0].clone(), fp8_dtype)
ker_quantizer = Float8Quantizer(fp8_scales[1].clone(), fp8_amaxes[1].clone(), fp8_dtype)
if opts.fp8_output:
out_quantizer = Float8Quantizer(fp8_scales[2].clone(), fp8_amaxes[2].clone(), fp8_dtype)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_quantizer = Float8Quantizer(
fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
)
elif ub_obj2 is not None:
inp2_quantizer = Float8Quantizer(
fp8_scales[3].clone(), fp8_amaxes[3].clone(), fp8_dtype
)
ker2_quantizer = Float8Quantizer(
fp8_scales[4].clone(), fp8_amaxes[4].clone(), fp8_dtype
)
if opts.fp8_output:
out2_quantizer = Float8Quantizer(
fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
)
# Cast input to Float8Tensor
inp_fp8 = inp_quantizer(inp)
# Cast kernel to Float8Tensor
kernel_t_fp8 = ker_quantizer(kernel_t)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp)
elif ub_obj2 is not None:
kernel2_t_fp8 = ker2_quantizer(kernel2_t)
        # Make sure the inputs were cast correctly
        if opts.check_numerics:
            assert torch.allclose(
                inp.to(dtype=torch.float32),
                inp_fp8.dequantize(dtype=torch.float32),
                rtol=0.125,
                atol=0.0675,
            ), "FP8 cast of the input does not match the original input."
            assert torch.allclose(
                kernel_t.to(dtype=torch.float32),
                kernel_t_fp8.dequantize(dtype=torch.float32),
                rtol=0.125,
                atol=0.0675,
            ), "FP8 cast of the GEMM1 kernel does not match the original kernel."
            if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
                assert torch.allclose(
                    bulk_inp.to(dtype=torch.float32),
                    bulk_inp_fp8.dequantize(dtype=torch.float32),
                    rtol=0.125,
                    atol=0.0675,
                ), "FP8 cast of the bulk input does not match the original input."
            elif ub_obj2 is not None:
                assert torch.allclose(
                    kernel2_t.to(dtype=torch.float32),
                    kernel2_t_fp8.dequantize(dtype=torch.float32),
                    rtol=0.125,
                    atol=0.0675,
                ), "FP8 cast of the GEMM2 kernel does not match the original kernel."
# Set up comm/compute buffers
rs_out = None
rs_out2 = None
if opts.comm_type == tex.CommOverlapType.AG:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True)
gemm_inp = inp
else:
ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True)
gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size())
if ub_obj2 is not None:
if opts.fp8 and opts.fp8_output:
ub_obj2.set_buffer_params(out_quantizer)
rs_out2 = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
else:
if opts.bulk_overlap:
ub_obj.copy_into_buffer(
bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False
)
if opts.fp8:
ub_obj.set_buffer_params(bulk_inp_quantizer)
elif opts.fp8 and opts.fp8_output:
ub_obj.set_buffer_params(out_quantizer)
gemm_inp = inp_fp8 if opts.fp8 else inp
rs_out = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
)
# Wrap GEMM ops in condensed functions to make CUDA Graphs easier to use
def _fp8_gemm():
return tex.general_gemm(
kernel_t_fp8,
gemm_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub=ub_obj,
ub_type=opts.comm_type,
extra_output=rs_out,
bulk_overlap=opts.bulk_overlap,
)
def _fp8_gemm2(gemm1_out):
gemm2_inp = tex.gelu(
(gemm1_out.dequantize() if opts.fp8_output else gemm1_out),
inp2_quantizer,
)
return tex.general_gemm(
kernel2_t_fp8,
gemm2_inp,
workspace,
out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16,
quantization_params=out2_quantizer,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub=ub_obj2,
ub_type=tex.CommOverlapType.AG,
extra_output=rs_out2,
)
def _gemm():
return tex.general_gemm(
kernel_t,
gemm_inp,
workspace,
out_dtype=torch.bfloat16,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
ub=ub_obj,
ub_type=opts.comm_type,
extra_output=rs_out,
bulk_overlap=opts.bulk_overlap,
)
# Trigger GEMM
total_iters = opts.warmup_iters + opts.timing_iters
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(total_iters)]
torch.cuda.synchronize()
if opts.use_cuda_graphs:
# Trace the CUDA graph first
g = torch.cuda.CUDAGraph()
        if opts.fp8:
            if ub_obj2 is None:
                with torch.cuda.graph(g):
                    all_outputs = _fp8_gemm()
            else:
                with torch.cuda.graph(g):
                    all_outputs = _fp8_gemm()
                    _ = _fp8_gemm2(all_outputs[0])
else:
with torch.cuda.graph(g):
all_outputs = _gemm()
# Now replay the CUDA graph in a loop
for i in range(total_iters):
start_events[i].record()
g.replay()
end_events[i].record()
else:
for i in range(total_iters):
if opts.fp8:
start_events[i].record()
all_outputs = _fp8_gemm()
end_events[i].record()
if ub_obj2 is not None:
_fp8_gemm2(all_outputs[0])
else:
start_events[i].record()
all_outputs = _gemm()
end_events[i].record()
torch.cuda.synchronize()
gpu_times = [
s.elapsed_time(e)
for s, e in zip(start_events[opts.warmup_iters :], end_events[opts.warmup_iters :])
]
avg_gpu_time = sum(gpu_times) / opts.timing_iters
gemm_name = "".join(
[
"p2p all-gather + " if opts.comm_type == tex.CommOverlapType.AG else "",
"atomic " if opts.atomic else "",
"GEMM",
(
f" + {'p2p ' if opts.p2p else ''}reduce-scatter"
if opts.comm_type == tex.CommOverlapType.RS
else ""
),
]
)
timing_info = (
f"Avg. GPU time for {gemm_name}: {avg_gpu_time} ms "
+ f"({opts.warmup_iters} warmup + {opts.timing_iters} timing runs)"
)
dist_print(timing_info, section=True, info=True, group=tp_group)
# Compare against standard GEMM
numerics_failed = False
if opts.check_numerics:
torch.cuda.synchronize()
dist.barrier(tp_group)
if opts.bulk_overlap:
output_info = ""
if opts.comm_type == tex.CommOverlapType.AG:
# Bulk overlap AG output is already gathered
test_out = ub_obj.get_buffer(bulk_inp_quantizer, False)
else:
# Bulk overlap RS output needs to be gathered
out_local = ub_obj.get_buffer(bulk_inp_quantizer, True)
output_info += f"rs_output: {list(out_local.shape)} | "
test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]
ref_out = ref_g
output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}"
dist_print(
output_info,
src=0 if opts.comm_type == tex.CommOverlapType.RS else None,
section=True,
)
test_nonzeros = torch.count_nonzero(test_out)
ref_nonzeros = torch.count_nonzero(ref_out)
            nonzero_info = (
                f"output nonzeros = {test_nonzeros} " + f"| reference nonzeros = {ref_nonzeros}"
            )
dist_print(nonzero_info, src=0, section=True, group=tp_group)
else:
if opts.comm_type == tex.CommOverlapType.AG:
if ub_obj2 is not None:
# AG+RS Output: (M/P, N) -> gather -> (M, N)
output = rs_out2.to(dtype=torch.float32)
test_out = te.distributed.gather_along_first_dim(output, tp_group)[0]
else:
# AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K)
output = all_outputs[0].dequantize() if opts.fp8_output else all_outputs[0]
test_out = torch.transpose(
te.distributed.gather_along_first_dim(
torch.transpose(output, 0, 1), tp_group
)[0],
0,
1,
)
else:
# RS Output: (M/P, N) -> gather -> (M, N)
output = rs_out.to(dtype=torch.float32)
test_out = te.distributed.gather_along_first_dim(output, tp_group)[0]
ref_out = ref2_g if ub_obj2 is not None else ref_g
test_nonzeros = torch.count_nonzero(test_out)
ref_nonzeros = torch.count_nonzero(ref_out)
            nonzero_info = (
                f"output nonzeros = {test_nonzeros} " + f"| reference nonzeros = {ref_nonzeros}"
            )
dist_print(nonzero_info, src=0, section=True, group=tp_group)
sizing_info = (
f"input: {list(inp.shape)} " + f"| GEMM1 weights: {list(kernel_t.shape)[::-1]} "
)
if ub_obj2 is not None:
sizing_info += f"| GEMM2 weights: {list(kernel2_t.shape)[::-1]} "
sizing_info += f"| output: {list(output.shape)}\n"
dist_print(sizing_info, section=True, group=tp_group)
sizing_info_g = (
f"input: {list(inp_g.shape)} " + f"| GEMM1 weights: {list(ker_g.shape)} "
)
if ub_obj2 is not None:
sizing_info_g += f"| GEMM2 weights: {list(ker2_g.shape)} "
sizing_info_g += (
f"| output: {list(test_out.shape)} " + f"| reference: {list(ref_out.shape)}\n"
)
dist_print(sizing_info_g, src=0, group=tp_group)
torch.cuda.synchronize()
dist.barrier(tp_group)
diff = torch.abs(test_out - ref_out).flatten()
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
rtol = 0.125 if opts.fp8 else 0.02
atol = 0.0625 if opts.fp8 else 0.001
if rel_err > rtol and abs_err > atol:
numerics_failed = True
numerics_info = (
"NUMERICAL CHECK FAILED: "
+ f"Outputs not close enough at index {m.item()} "
+ f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} | "
+ f"rel. error = {rel_err} (tol = {rtol}) | "
+ f"abs. error = {abs_err} (tol = {atol})"
)
else:
numerics_info = "NUMERICAL CHECK PASSED: "
            if rel_err <= rtol:
                numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
                    " | " if abs_err <= atol else ""
                )
if abs_err <= atol:
numerics_info += f"abs. error = {abs_err} (tol = {atol})"
dist_print(
numerics_info, src=0, section=True, info=True, error=numerics_failed, group=tp_group
)
dist.barrier(tp_group)
if LOCAL_RANK == 0:
print("\n", end="", flush=True)
dist.destroy_process_group()
# Reset clock speeds
if opts.clock_speed > 0:
subprocess.run(
["nvidia-smi", "-pm", "ENABLED", "-i", str(LOCAL_RANK)],
env=os.environ,
check=False,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
result = subprocess.run(
["nvidia-smi", "-rgc", "-i", str(LOCAL_RANK)],
env=os.environ,
check=False,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return int(numerics_failed)
if __name__ == "__main__":
sys.exit(_main(_parse_args()))
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import socket
import subprocess
import argparse
import warnings
import pprint
import yaml
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
class multi_module_model(torch.nn.Module):
def __init__(self, module, num_layers, *args, **kwargs):
super().__init__()
self.num_layers = num_layers
self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def _te_layer_argtype(name):
te_layers = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers))
if name.lower() not in layer_map.keys():
raise argparse.ArgumentTypeError(
f"Invalid TE layer name! Please choose from: {layer_map.keys()}"
)
return layer_map[name.lower()]
def _get_layer_args(config, tp_group, tp_size, num_layers, reference=False):
hidden_size = config.num_heads * config.head_dim
ffn_hidden_size = 4 * hidden_size
qkv_size = 3 * hidden_size
if num_layers > 1 and config.layer_type != te.TransformerLayer:
raise ValueError("Stacked layers are only supported for te.TransformerLayer!")
input_shape = [config.seq_length, config.batch_size, hidden_size]
args = [hidden_size]
kwargs = {
"params_dtype": torch.float32 if not config.use_bf16_params else torch.bfloat16,
"device": "cuda",
"tp_group": tp_group,
"tp_size": tp_size,
"sequence_parallel": True,
"ub_overlap_ag": not reference,
"ub_overlap_rs": not reference,
}
if config.layer_type in [te.Linear, te.LayerNormLinear]:
if config.linear_parallel_mode == "row":
input_shape[-1] = ffn_hidden_size // tp_size
args = [ffn_hidden_size, hidden_size]
if config.in_features is not None:
input_shape[-1] = config.in_features // tp_size
args = [config.in_features, hidden_size]
kwargs["ub_name"] = "proj" if config.layer_type == te.Linear else "fc2"
kwargs["ub_name"] = kwargs["ub_name"] if config.ub_name is None else config.ub_name
elif config.linear_parallel_mode == "column":
input_shape[0] = config.seq_length // tp_size
if config.out_features is not None:
args.append(config.out_features)
else:
args.append(qkv_size)
kwargs["ub_name"] = "qkv" if config.ub_name is None else config.ub_name
kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["parallel_mode"] = config.linear_parallel_mode
else:
input_shape[0] = config.seq_length // tp_size
if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]:
args.append(ffn_hidden_size)
kwargs["seq_length"] = config.seq_length
if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
args.append(config.num_heads)
kwargs["attention_dropout"] = 0.0
kwargs["fuse_qkv_params"] = True
if config.layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True
else:
kwargs["ub_tp_comm_overlap"] = not reference
kwargs["hidden_dropout"] = 0.0
kwargs["set_parallel_mode"] = True
kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
if config.ub_cfg is not None and isinstance(config.ub_cfg, str):
with open(config.ub_cfg, "r") as stream:
config.ub_cfg = yaml.safe_load(stream)
return args, kwargs, input_shape
def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(
description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers."
)
parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP)
parser.add_argument(
"--num-layers", type=int, default=1, help="Number of identical layers to stack."
)
parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.")
parser.add_argument(
"-n", "--num-heads", type=int, default=16, help="Number of attention heads."
)
parser.add_argument(
"-d", "--head-dim", type=int, default=48, help="Dimension of each attention head."
)
parser.add_argument(
"--in-features",
type=int,
default=None,
help="Optional input feature size for weight. Only used for Linear layer.",
)
parser.add_argument(
"--out-features",
type=int,
default=None,
help="Optional output feature size for weight. Only used for LayerNormLinear layer.",
)
parser.add_argument(
"--tp",
type=int,
default=None,
help="Optional tensor_model_parallel_size used to initialize UB.",
)
parser.add_argument(
"--use-bf16-params",
action="store_true",
default=False,
help="Use BF16 params instead of FP32.",
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
)
parser.add_argument(
"--quantization",
type=str.lower,
default="none",
choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
help="Quantization recipe",
)
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
parser.add_argument(
"--tcp-init",
action="store_true",
default=False,
help="Initialize torch.distributed with TcpStore.",
)
parser.add_argument(
"--bind-to-device",
action="store_true",
default=False,
help="Initialize torch.distributed with `device_id` to bind each rank to a single device.",
)
parser.add_argument(
"--bootstrap-backend",
type=str.lower,
default="nccl",
choices=["gloo", "mpi", "nccl"],
help="Communications backend for host tensor collectives during Userbuffers bootstrapping.",
)
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
)
parser.add_argument(
"--ub-cfg", type=str, default=None, help="Optional TP config yaml file input."
)
parser.add_argument("--ub-name", type=str, default=None, help="Optional TP layer name.")
parser.add_argument(
"--skip-verify",
action="store_true",
default=False,
help="Skip numerics check.",
)
parser.add_argument(
"--benchmark",
action="store_true",
default=False,
help="Benchmark comm-gemm overlap perf.",
)
parser.add_argument(
"--benchmark-iter",
type=int,
default=100,
help="Number of iterations for benchmarking perf.",
)
parser.add_argument(
"--linear-parallel-mode",
type=str.lower,
default="row",
choices=["row", "column"],
help="Parallel mode for te.Linear.",
)
parser.add_argument(
"--overlap-rs-dgrad",
action="store_true",
default=False,
help="Replace bulk DGRAD/WGRAD overlaps with DGRAD+RS in the backward pass for AG+GEMM.",
)
parser.add_argument(
"--debug",
action="store_true",
default=False,
help="Print out additional debug information.",
)
args = parser.parse_args(argv, namespace)
if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
args.use_cuda_graphs = False
return args
def _compare_tensors(name, test, ref, rtol, atol):
# Make sure tensors aren't zero and we don't pass trivially
if test.count_nonzero() == 0:
if ref.count_nonzero() == 0:
warnings.warn(
f"WARNING: {name} is a zero-tensor for both test and reference models!",
category=RuntimeWarning,
)
else:
numerics_info = (
f"NUMERICAL CHECK FAILED: {name} is a zero-tensor but does not match reference!"
)
return 1, numerics_info
diff = torch.abs(test.flatten() - ref.flatten())
m = torch.argmax(diff)
abs_err = diff[m].item()
rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5)
numerics_failed = 0
if rel_err > rtol and abs_err > atol:
numerics_failed = 1
numerics_info = (
"NUMERICAL CHECK FAILED: "
+ f"{name} not close enough at index {m.item()} "
+ f"with {test.flatten()[m].item()} vs {ref.flatten()[m].item()} | "
+ f"rel. error = {rel_err} (tol = {rtol}) | "
+ f"abs. error = {abs_err} (tol = {atol})"
)
else:
numerics_info = f"NUMERICAL CHECK PASSED: {name} | "
if rel_err <= rtol:
numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
" | " if abs_err <= atol else "."
)
if abs_err <= atol:
numerics_info += f" abs. error = {abs_err} (tol = {atol})"
return numerics_failed, numerics_info
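# Illustrative usage sketch: the helper returns an (exit_code, message) pair
# rather than raising, so callers can all-reduce the failure flag across ranks
# before deciding how to report it.
#
#   failed, info = _compare_tensors("output", test_out, ref_out, rtol=0.125, atol=0.0625)
#   if failed:
#       print(info, file=sys.stderr)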
def _train(opts):
if "OMPI_COMM_WORLD_SIZE" in os.environ:
# Execution with `mpirun -np N`
WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
opts.tcp_init = True
opts.bind_to_device = True
opts.bootstrap_backend = "mpi"
else:
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
result = subprocess.run(
"nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
capture_output=True,
text=True,
shell=True,
)
if result.stdout == "0" and opts.tp is None: # Extra checks for non-MNNVL platforms
assert WORLD_SIZE == LOCAL_SIZE
# Initialize torch.distributed tp process group
new_group_kwargs = {
"backend": "nccl",
}
if opts.tp is not None:
LOCAL_SIZE = opts.tp
tp_base_rank = (WORLD_RANK // LOCAL_SIZE) * LOCAL_SIZE
tp_rank_list = list(range(tp_base_rank, tp_base_rank + LOCAL_SIZE))
new_group_kwargs = {
"backend": "nccl",
"ranks": tp_rank_list,
}
else:
opts.tp = WORLD_SIZE
# Tensor dim overrides for tensors that do not require TP communication
if opts.in_features is not None:
assert opts.layer_type is te.Linear and opts.linear_parallel_mode == "row", (
"--in-features is only used to configure row-tensor-parallel Linear layers. Use"
" --num-heads or --head-dim for other cases."
)
if opts.out_features is not None:
assert opts.layer_type is te.LayerNormLinear and opts.linear_parallel_mode == "column", (
"--out-features is only used to configure column-tensor-parallel LayerNormLinear"
" layers. Use --num-heads or --head-dim for other cases."
)
def dist_print(msg, src=None, end="\n", debug=False, error=False):
if debug and not opts.debug:
return
stream = sys.stderr if error else sys.stdout
if WORLD_RANK == (0 if src is None else src):
stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n")
dist.barrier()
# Set device and initialize RNG states
torch.cuda.set_device(LOCAL_RANK)
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)
# Initialize torch.distributed global process group and get DP/TP groups
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
if opts.tcp_init:
MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
MASTER_PORT = os.getenv("MASTER_PORT", "1234")
dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
if opts.bind_to_device or opts.bootstrap_backend == "nccl":
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(**new_group_kwargs)
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Initialize the Transformer Engine layer with overlap
args, kwargs, input_shape = _get_layer_args(
opts, nccl_world, opts.tp, num_layers=opts.num_layers
)
    # Initialize Userbuffers
ub_cfgs = None
if opts.overlap_rs_dgrad:
ub_cfgs = {
"qkv_dgrad": {"method": "ring_exchange"},
"fc1_dgrad": {"method": "ring_exchange"},
}
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
opts.tp,
use_fp8=opts.fp8,
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
)
with te.fp8_model_init(enabled=opts.fp8_init):
test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs)
dist_print("Initialized test model...", debug=True)
if WORLD_RANK == 0:
pprint.pprint(kwargs)
sys.stdout.write("\n")
dist.barrier()
# Initialize the reference model and copy all parameters
ref_args, ref_kwargs, _ = _get_layer_args(
opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True
)
with te.fp8_model_init(enabled=opts.fp8_init):
ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs)
dist_print("Initialized reference model...", debug=True)
for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
with torch.no_grad():
ref_param.copy_(test_param)
torch.testing.assert_close(test_param, ref_param, rtol=0.0, atol=0.0)
dist_print("Copied parameters from test model to reference model...", debug=True)
# Fp8 recipe setup
fp8_format = Format.HYBRID
fp8_recipe = None
if opts.quantization == "fp8_delayed_scaling":
fp8_recipe = DelayedScaling(
fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max"
)
elif opts.quantization == "fp8_current_scaling":
fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
# Prepare random input tensors
test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
test_x.retain_grad()
ref_x = torch.empty_like(test_x).requires_grad_(True)
with torch.no_grad():
ref_x.copy_(test_x)
torch.testing.assert_close(test_x, ref_x, rtol=0.0, atol=0.0)
ref_x.retain_grad()
# Execute fwd/bwd and collect tensors to test
def run_fwd_bwd(model, x):
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
y = model(x)
if isinstance(y, tuple):
out, *_ = y
else:
out = y
loss = out.sum()
loss.backward()
return out
torch_rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{LOCAL_RANK}"))
if opts.use_cuda_graphs:
test_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(test_graph):
test_out = run_fwd_bwd(test_model, test_x)
test_graph.replay()
if not opts.benchmark:
del test_graph
else:
test_out = run_fwd_bwd(test_model, test_x)
test_grads = [test_out, test_x.grad]
names = ["output", "input.grad"]
for test_name, test_param in test_model.named_parameters():
if test_param.requires_grad and "layer_norm" not in test_name:
test_grads.append(test_param.grad)
names.append(test_name + ".grad")
torch.set_rng_state(torch_rng_state)
torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{LOCAL_RANK}"))
if opts.use_cuda_graphs:
ref_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(ref_graph):
ref_out = run_fwd_bwd(ref_model, ref_x)
ref_graph.replay()
del ref_graph
else:
ref_out = run_fwd_bwd(ref_model, ref_x)
ref_grads = [ref_out, ref_x.grad]
for ref_name, ref_param in ref_model.named_parameters():
if ref_param.requires_grad and "layer_norm" not in ref_name:
ref_grads.append(ref_param.grad)
numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
if not opts.skip_verify:
# Make sure we have the same number of gradients
if len(test_grads) != len(ref_grads):
numerics_failed[0] = 1
numerics_info = (
"NUMERICAL CHECK FAILED: Incorrect number of gradients, "
+ f"expected {len(ref_grads)} but got {len(test_grads)}."
)
dist_print(numerics_info, src=WORLD_RANK, error=True)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
# Now validate accuracy
if not bool(numerics_failed.item()):
for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
rtol = 0.125 if opts.fp8 else 0.025
atol = 0.0625 if opts.fp8 else 0.00125
grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
numerics_failed[0] = int(grad_failed)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
if bool(numerics_failed.item()) and not opts.debug:
break
if opts.benchmark:
        # Warmup iterations so that profiling excludes one-time CPU overhead
for _ in range(100):
if opts.use_cuda_graphs:
test_graph.replay()
else:
test_out = run_fwd_bwd(test_model, test_x)
torch.cuda.cudart().cudaProfilerStart()
for _ in range(opts.benchmark_iter):
if opts.use_cuda_graphs:
test_graph.replay()
else:
test_out = run_fwd_bwd(test_model, test_x)
torch.cuda.cudart().cudaProfilerStop()
if opts.use_cuda_graphs:
del test_graph
te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True)
dist_print("Destroying all process groups...", debug=True)
dist.destroy_process_group()
if opts.debug and WORLD_RANK == 0:
print("Exiting...\n", end="", flush=True)
return numerics_failed[0].item()
if __name__ == "__main__":
sys.exit(_train(_parse_args()))
#!/usr/bin/python3
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import datetime
import os
import sys
from functools import wraps
import transformer_engine.pytorch as te
import torch
from torch import nn
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
MXFP8BlockScaling,
DelayedScaling,
Float8CurrentScaling,
Format,
Recipe,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from run_layer_with_overlap import _compare_tensors
SEQ_LEN, BATCH_SIZE = 16, 16
HIDDEN_SIZE = 64
NR_HEADS = 4
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
LOSS_FN = nn.MSELoss()
QUANTIZATION = None
# Disable TF32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
# Quantization recipe setup
def quantization_recipe() -> Recipe:
if QUANTIZATION == "fp8":
return DelayedScaling(
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
if QUANTIZATION == "mxfp8":
return MXFP8BlockScaling()
if QUANTIZATION == "fp8_cs":
return Float8CurrentScaling()
return te.fp8.get_default_fp8_recipe()
def main(argv=None, namespace=None):
global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
"timeout": datetime.timedelta(seconds=30),
}
dist_init_kwargs["init_method"] = "env://"
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(**dist_init_kwargs)
NCCL_WORLD = dist.new_group(backend="nccl")
WORLD_SIZE = dist.get_world_size()
parser = argparse.ArgumentParser()
parser.add_argument("-l", "--layer-type", type=str)
parser.add_argument("--quantization", type=str, default=None)
args = parser.parse_args(argv, namespace)
# Quantization scheme
QUANTIZATION = args.quantization
if QUANTIZATION in ("fp8", "mxfp8"):
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
SEQ_LEN = 32
BATCH_SIZE = 32
HIDDEN_SIZE = 128
    tests = [
        test_quantizer,
        test_linear,
        test_layernorm,
        test_layernorm_linear,
        test_layernorm_mlp,
        test_transformer_layer,
    ]
    for test in tests:
        test()
dist.destroy_process_group()
return 0
def run_distributed_test(test_name=None):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
name = test_name if test_name is not None else func.__name__
dist_print(f"Starting test {name} with args {args} and {kwargs}")
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(12345)
torch.cuda.manual_seed(12345)
func(*args, **kwargs)
dist.barrier()
dist_print(f"Passed test {name}")
return wrapper
return decorator
def _gather(tensor, dim=0):
    """
    Gathers tensors from all ranks and concatenates them along `dim`.

    torch.distributed.nn.functional.all_gather effectively multiplies incoming
    gradients by WORLD_SIZE in the backward pass, so those gradients are rescaled.
    """

    class HalfGradient(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            return input  # forward pass (identity)

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output / WORLD_SIZE  # undo the WORLD_SIZE gradient scaling

    tensor = HalfGradient.apply(tensor)
    gathered = torch.distributed.nn.functional.all_gather(tensor, group=NCCL_WORLD)
    return torch.cat(gathered, dim=dim)
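# Illustrative check of the rescaling (a sketch, not executed by these tests):
# with loss = _gather(x).sum() computed on every rank, the raw all_gather would
# yield x.grad == WORLD_SIZE everywhere; the HalfGradient wrapper restores the
# single-GPU gradient of 1.
#
#   x = torch.ones(4, device="cuda", requires_grad=True)
#   _gather(x).sum().backward()
#   assert torch.allclose(x.grad, torch.ones_like(x))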
def _constant(tensor):
return nn.init.constant_(tensor, 0.5)
def dist_print(msg, src=None, end="\n", error=False):
stream = sys.stderr if error else sys.stdout
if WORLD_RANK == (0 if src is None else src):
stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n")
def _get_tolerances(dtype):
    # Loose tolerances for fp8_cs: with row parallelism + sequence parallelism the
    # all-gather happens in the backward pass, so the amax reduction can leave each
    # rank computing Y with a slightly different scale_inv.
    if QUANTIZATION == "fp8_cs":
        return {"rtol": 0.4, "atol": 0.25}
    elif QUANTIZATION is not None:
        return {"rtol": 0.125, "atol": 0.0625}
if dtype == torch.float16:
return {"rtol": 1e-3, "atol": 1e-5}
if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32:
return {"rtol": 1.3e-6, "atol": 1e-5}
raise ValueError(f"Unsupported dtype ({dtype})")
def _check_outputs(output_single_node, output_distributed):
numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
output_failed, output_info = _compare_tensors(
"outputs",
output_distributed,
output_single_node,
**_get_tolerances(output_single_node.dtype),
)
if output_failed:
dist_print(output_info, src=WORLD_RANK, error=output_failed)
numerics_failed[0] = int(output_failed)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD)
assert not bool(numerics_failed.item())
def _match_param_sizes(dist_param, single_param):
"""
Adjust single_param to match the shape of dist_param
by slicing along dimensions where the shapes differ.
This function is typically used in a distributed setting
where single_param is a larger tensor that needs
to be partitioned among multiple processes.
Args:
dist_param: Tensor representing the distributed output
with the desired shape for the current process.
single_param: Tensor representing the non-distributed output,
possibly larger than dist_param.
Returns:
Tensor: Sliced version of single_param matching
the shape of dist_param for the current process.
"""
# Initialize indices for slicing with full slices for each dimension
indices = [slice(None)] * len(single_param.shape)
# Iterate over each dimension to identify where shapes differ
for i in range(len(dist_param.shape)):
if dist_param.shape[i] != single_param.shape[i]:
# Calculate the start and end indices for slicing based on the world rank
start = WORLD_RANK * dist_param.shape[i]
end = (WORLD_RANK + 1) * dist_param.shape[i]
src_slice = slice(start, end)
# Update the slicing indices for the current dimension
indices[i] = src_slice
# Slice single_param to obtain the output matching dist_param's shape
to_output = single_param[tuple(indices)]
return to_output
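# Worked example (sketch): with WORLD_SIZE == 2, a column-parallel weight of
# global shape (128, 64) is held as a (64, 64) shard per rank. On WORLD_RANK == 1
# this returns single_param[64:128, :], i.e. the second rank's shard.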
def _check_gradients(model_distributed, model_single, main_grad_check=False):
for i, ((name, param_d), param_s) in enumerate(
zip(model_distributed.named_parameters(), model_single.parameters())
):
numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
grad_failed, grad_info = None, None
if main_grad_check:
param_s_grad = _match_param_sizes(param_d.main_grad, param_s.main_grad)
grad_failed, grad_info = _compare_tensors(
str(i), param_d.main_grad, param_s_grad, **_get_tolerances(param_s_grad.dtype)
)
else:
param_s_grad = _match_param_sizes(param_d.grad, param_s.grad)
grad_failed, grad_info = _compare_tensors(
str(i), param_d.grad, param_s_grad, **_get_tolerances(param_s_grad.dtype)
)
if grad_failed:
dist_print(i, src=WORLD_RANK)
dist_print(name, src=WORLD_RANK)
dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
numerics_failed[0] = int(grad_failed)
dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD)
assert not bool(numerics_failed.item())
def _copy_params(model_distributed, model_single):
    for dist_param, single_param in zip(model_distributed.parameters(), model_single.parameters()):
        with torch.no_grad():
            # Slice the single-GPU parameter down to this rank's shard (a no-op
            # when the shapes already match) and copy it into the local parameter.
            dist_param.copy_(_match_param_sizes(dist_param, single_param))
def _apply_models(
model_single_node, model_distributed, input_single_node, input_distributed, **kwargs
):
_alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True
input_single_node.requires_grad_()
input_distributed.requires_grad_()
with te.fp8_autocast(
enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(),
):
output_single_node = model_single_node(input_single_node, **kwargs)
with te.fp8_autocast(
enabled=QUANTIZATION is not None,
fp8_recipe=quantization_recipe(),
fp8_group=NCCL_WORLD,
):
output_distributed = model_distributed(input_distributed, **kwargs)
return output_single_node, output_distributed
def _loss_backward(output_single_node, output_distributed):
target = torch.randn_like(output_single_node)
LOSS_FN(output_single_node, target).backward()
LOSS_FN(output_distributed, target).backward()
def _alloc_main_grad(model_single_node, model_distributed):
for model in [model_single_node, model_distributed]:
for param in model.parameters():
param.main_grad = torch.zeros_like(param, dtype=torch.float32)
###############################################
# Quantizer #
###############################################
def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
"""
quantizer is the reference quantizer on a single GPU.
quantizer_dist is the distributed quantizer to be tested on multiple GPUs.
"""
if quantizer_class == Float8CurrentScalingQuantizer:
quantizer_dist = quantizer_class(
fp8_dtype=fp8_dtype,
device=device,
with_amax_reduction=True,
amax_reduction_group=tp_group,
amax_reduction_size=tp_size,
)
quantizer = quantizer_class(
fp8_dtype=fp8_dtype,
device=device,
with_amax_reduction=False,
)
return quantizer, quantizer_dist
else:
raise ValueError(f"Unsupported quantizer class: {quantizer_class}")
def _shard_tensor(x, world_size, axis):
split_size = x.size()[axis] // world_size
split_tensor = torch.split(x, split_size, axis)
out = []
for tensor in split_tensor:
out.append(tensor.detach().clone().requires_grad_(x.requires_grad).cuda())
return out
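# Illustrative example (sketch): sharding an (8, HIDDEN_SIZE) tensor across
# WORLD_SIZE == 2 yields two (4, HIDDEN_SIZE) CUDA chunks; each rank then
# indexes its own chunk with [WORLD_RANK], as _test_quantizer does below.
#
#   chunks = _shard_tensor(torch.randn(8, HIDDEN_SIZE), 2, 0)
#   local = chunks[1]  # rank 1's shard, shape (4, HIDDEN_SIZE)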
@run_distributed_test()
def _test_quantizer(input_dtype, fp8_dtype):
"""Test the quantizer under distributed settings.
Args:
input_dtype (torch.dtype): The data type of the input.
fp8_dtype (tex.DType): The data type of the fp8.
"""
M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE
# high precision input
x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype)
    # Set one element of the input to a very large value. After the split it lives
    # on the last rank rather than rank 0, which deliberately exercises the amax reduction.
x_hp_cpu[M - 1, N - 1] = 1e4
# rank 0 takes the full copy and quantize with GPU 0 for verification
if WORLD_RANK == 0:
x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda")
x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK]
# Create quantizers
quantizer, quantizer_dist = _construct_quantizer(
Float8CurrentScalingQuantizer, fp8_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE
)
# quantize the input
if WORLD_RANK == 0:
x_fp8_single = quantizer(x_hp_rank0)
# multi-GPU quantizer
x_fp8_dist = quantizer_dist(x_hp_local_rank)
# check scale_inv with zero tolerance
if WORLD_RANK == 0:
torch.testing.assert_close(
x_fp8_single._scale_inv, x_fp8_dist._scale_inv, rtol=0.0, atol=0.0
)
def test_quantizer():
"""
Run quantizer tests with various configurations.
Currently only check fp8_cs because it needs to do amax reduction in the quantizer.
"""
# skip this test for other quantization schemes
if QUANTIZATION != "fp8_cs":
return
input_dtypes = [torch.float32, torch.bfloat16]
fp8_dtypes = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
for input_dtype in input_dtypes:
for fp8_dtype in fp8_dtypes:
_test_quantizer(input_dtype, fp8_dtype)
############################################
# Linear #
############################################
@run_distributed_test()
def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the linear layer with specified parallel mode and sequence parallelization.
Args:
parallel_mode (str): 'row' or 'column' parallelism.
sequence_parallel (bool): Enable sequence parallelism if True.
kwargs (dict): Additional arguments for the linear layer.
"""
# Set parameter data type
params_dtype = kwargs.get("params_dtype", torch.float32)
# Create models
model_single_node = te.Linear(HIDDEN_SIZE, HIDDEN_SIZE, **kwargs)
model_distributed = te.Linear(
HIDDEN_SIZE,
HIDDEN_SIZE,
tp_size=WORLD_SIZE,
tp_group=NCCL_WORLD,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
**kwargs,
)
# Synchronize parameters between models
_copy_params(model_distributed, model_single_node)
# Prepare input tensors
input_single_node = torch.randn((BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype)
if parallel_mode == "row":
# Split input across GPUs for row parallelism
split_size = HIDDEN_SIZE // WORLD_SIZE
input_distributed = input_single_node[
:, WORLD_RANK * split_size : (WORLD_RANK + 1) * split_size
].clone()
elif parallel_mode == "column":
if sequence_parallel:
            # Each rank draws its own local input; the single-GPU input is gathered from all ranks below
input_single_node = (
torch.empty((WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype)
)
input_distributed = torch.randn((BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype)
            # For fp8_cs, inject an outlier to verify that the amax reduction works
if QUANTIZATION == "fp8_cs":
input_distributed = torch.clamp(input_distributed, min=-10, max=10)
if WORLD_RANK == WORLD_SIZE - 1:
input_distributed[BATCH_SIZE - 1, HIDDEN_SIZE - 1] = 11
input_single_node = _gather(input_distributed, dim=0).detach()
else:
input_distributed = input_single_node.clone()
else:
raise ValueError(f"Invalid parallel_mode: {parallel_mode}")
# Apply models
output_single_node, output_distributed = _apply_models(
model_single_node, model_distributed, input_single_node, input_distributed
)
if "return_bias" in kwargs:
output_single_node, bias_s = output_single_node
output_distributed, bias_d = output_distributed
if parallel_mode == "column":
bias_d = _gather(bias_d)
_check_outputs(bias_s, bias_d)
# Gather outputs if necessary
if parallel_mode == "column" or (sequence_parallel and parallel_mode == "row"):
output_distributed = _gather(output_distributed, dim=1 if parallel_mode == "column" else 0)
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
# gradients in other cases need additional synchronization
if (parallel_mode == "column" or not sequence_parallel) and "return_bias" not in kwargs:
_check_gradients(
model_distributed,
model_single_node,
main_grad_check=("fuse_wgrad_accumulation" in kwargs),
)
def test_linear():
"""Run linear layer tests with various configurations."""
kwargs_list = [
{},
{"bias": False},
{"init_method": _constant},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"params_dtype": torch.float16},
]
for kwargs in kwargs_list:
for parallel_mode in ["column", "row"]:
for sequence_parallel in [False, True]:
_test_linear(parallel_mode, sequence_parallel, **kwargs)
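# Hedged sketch (hypothetical helper, never called) of the row-parallel math
# exercised above: splitting the input features and the weight's input
# dimension across ranks, then summing the partial products (the role of the
# all-reduce), reproduces the full GEMM.
def _demo_row_parallel_matmul():
    torch.manual_seed(0)
    x, w = torch.randn(4, 8), torch.randn(6, 8)
    full = x @ w.T
    partial = sum(x[:, s] @ w[:, s].T for s in (slice(0, 4), slice(4, 8)))
    torch.testing.assert_close(full, partial)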
############################################
# LayerNorm #
############################################
@run_distributed_test()
def _test_layernorm(kwargs):
"""Test LayerNorm and RMSNorm with given arguments.
Args:
kwargs (dict): Contains 'norm', 'basic_args', and 'distributed_args'.
"""
# Extract parameters
norm = kwargs["norm"]
basic_args = kwargs["basic_args"]
distributed_args = kwargs["distributed_args"]
params_dtype = basic_args.get("params_dtype", torch.float32)
# Create models
model_single_node = norm(HIDDEN_SIZE, **basic_args)
model_distributed = norm(HIDDEN_SIZE, **{**basic_args, **distributed_args})
# Synchronize parameters between models
_copy_params(model_distributed, model_single_node)
# Prepare input tensors
input_single_node = torch.randn((BATCH_SIZE, HIDDEN_SIZE), dtype=params_dtype).cuda()
input_distributed = input_single_node.clone()
# Apply models
output_single_node, output_distributed = _apply_models(
model_single_node, model_distributed, input_single_node, input_distributed
)
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
_check_gradients(model_distributed, model_single_node)
def test_layernorm():
"""Run LayerNorm and RMSNorm tests with various configurations."""
norms = [te.LayerNorm, te.RMSNorm]
# Define basic arguments for the models
basic_args_list = [
{"zero_centered_gamma": True},
{"params_dtype": torch.float16},
]
# Define distributed arguments
distributed_args_list = [
{},
{"sequence_parallel": True},
]
# Generate combinations of norms and arguments
for norm in norms:
for basic_args in basic_args_list:
for distributed_args in distributed_args_list:
kwargs = {
"norm": norm,
"basic_args": basic_args,
"distributed_args": distributed_args,
}
_test_layernorm(kwargs)
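# Hedged sketch (hypothetical helper, never called) of the zero_centered_gamma
# option covered above: the scale is stored as an offset from one, so an
# all-zero weight acts as identity and matches plain LayerNorm with default
# (unit) affine parameters.
def _demo_zero_centered_gamma():
    x = torch.randn(2, 8)
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    gamma = torch.zeros(8)  # zero-centered: effective scale is 1 + gamma
    y = (x - mean) / torch.sqrt(var + 1e-5) * (1 + gamma)
    torch.testing.assert_close(y, torch.nn.functional.layer_norm(x, (8,)))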
############################################
# LayerNormLinear #
############################################
@run_distributed_test()
def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the linear layer with specified parallel mode and sequence parallelization.
Args:
parallel_mode (str): 'row' or 'column' parallelism.
sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for the LayerNormLinear layer.
"""
# Set parameter data type
params_dtype = kwargs.get("params_dtype", torch.float32)
# Create models
model_single_node = te.LayerNormLinear(HIDDEN_SIZE, HIDDEN_SIZE, **kwargs)
model_distributed = te.LayerNormLinear(
HIDDEN_SIZE,
HIDDEN_SIZE,
tp_size=WORLD_SIZE,
tp_group=NCCL_WORLD,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
**kwargs,
)
# Synchronize parameters between models
_copy_params(model_distributed, model_single_node)
# Prepare input tensors
input_single_node = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
if sequence_parallel:
        # Allocate the full-size input; it is filled below by gathering each rank's shard
input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
        # For fp8_cs, clamp the input and plant a large value on the last rank
        # to exercise the amax reduction on purpose
if QUANTIZATION == "fp8_cs":
input_distributed = torch.clamp(input_distributed, min=-10, max=10)
if WORLD_RANK == WORLD_SIZE - 1:
input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11
input_single_node = _gather(input_distributed).detach()
else:
input_distributed = input_single_node.clone()
# Apply models
output_single_node, output_distributed = _apply_models(
model_single_node, model_distributed, input_single_node, input_distributed
)
if "return_layernorm_output" in kwargs:
output_single_node, norm_s = output_single_node
output_distributed, norm_d = output_distributed
if sequence_parallel:
norm_d = _gather(norm_d)
_check_outputs(norm_s, norm_d)
if "return_bias" in kwargs:
output_single_node, bias_s = output_single_node
output_distributed, bias_d = output_distributed
if parallel_mode == "column":
bias_d = _gather(bias_d)
_check_outputs(bias_s, bias_d)
# Gather outputs if necessary
if parallel_mode == "column" or (sequence_parallel and parallel_mode == "row"):
output_distributed = _gather(output_distributed, dim=1 if parallel_mode == "column" else 0)
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
# gradients in other cases need additional synchronization
if parallel_mode == "column" and not sequence_parallel and "return_bias" not in kwargs:
_check_gradients(
model_distributed,
model_single_node,
main_grad_check=("fuse_wgrad_accumulation" in kwargs),
)
def test_layernorm_linear():
    """Run LayerNormLinear tests with various configurations."""
    kwargs_list = [
{},
{"bias": False},
{"init_method": _constant},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"params_dtype": torch.float16},
{"zero_centered_gamma": False},
{"return_layernorm_output": True},
]
for kwargs in kwargs_list:
for parallel_mode in ["column"]:
for sequence_parallel in [False, True]:
_test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)
############################################
# LayerNormMLP #
############################################
@run_distributed_test()
def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwargs):
"""Test the LayerNormMLP with specified parallel mode and sequence parallelization.
Args:
        set_parallel_mode (bool): Enable tensor parallelism if True.
        sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for the LayerNormMLP layer.
"""
# Set parameter data type
params_dtype = kwargs.get("params_dtype", torch.float32)
FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128
# Create models
model_single_node = te.LayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs)
model_distributed = te.LayerNormMLP(
HIDDEN_SIZE,
FFN_HIDDEN_SIZE,
tp_size=WORLD_SIZE,
tp_group=NCCL_WORLD,
set_parallel_mode=set_parallel_mode,
sequence_parallel=sequence_parallel,
**kwargs,
)
# Synchronize parameters between models
_copy_params(model_distributed, model_single_node)
# Prepare input tensors
input_single_node = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
if sequence_parallel:
        # Allocate the full-size input; it is filled below by gathering each rank's shard
input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype)
        # For fp8_cs, clamp the input and plant a large value on the last rank
        # to exercise the amax reduction on purpose
if QUANTIZATION == "fp8_cs":
input_distributed = torch.clamp(input_distributed, min=-10, max=10)
if WORLD_RANK == WORLD_SIZE - 1:
input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11
input_single_node = _gather(input_distributed).detach()
else:
input_distributed = input_single_node.clone()
# Apply models
output_single_node, output_distributed = _apply_models(
model_single_node, model_distributed, input_single_node, input_distributed
)
if "return_layernorm_output" in kwargs:
output_single_node, norm_s = output_single_node
output_distributed, norm_d = output_distributed
if sequence_parallel:
norm_d = _gather(norm_d)
_check_outputs(norm_s, norm_d)
if "return_bias" in kwargs:
output_single_node, bias_s = output_single_node
output_distributed, bias_d = output_distributed
_check_outputs(bias_s, bias_d)
if sequence_parallel:
output_distributed = _gather(output_distributed)
# Compute loss and backpropagate
_loss_backward(output_single_node, output_distributed)
# Validate outputs and gradients
_check_outputs(output_single_node, output_distributed)
# gradients in other cases need additional synchronization
if not sequence_parallel and "return_bias" not in kwargs:
_check_gradients(
model_distributed,
model_single_node,
main_grad_check=("fuse_wgrad_accumulation" in kwargs),
)
def test_layernorm_mlp():
    """Run LayerNormMLP tests with various configurations."""
    kwargs_list = [
{},
{"init_method": _constant},
{"output_layer_init_method": _constant},
{"normalization": "RMSNorm"},
{"zero_centered_gamma": True},
{"bias": False},
{"params_dtype": torch.float16},
{"activation": "relu"},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
{"return_layernorm_output": True},
]
for kwargs in kwargs_list:
for set_parallel_mode in [True]:
for sequence_parallel in [False, True]:
_test_layernorm_mlp(set_parallel_mode, sequence_parallel, **kwargs)
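# Hedged sketch (hypothetical helper, never called) of the parallelism enabled
# by set_parallel_mode: fc1 is column-parallel and fc2 is row-parallel, so each
# rank holds a slice of the hidden dimension and a single reduction (the sum
# below) finishes the MLP.
def _demo_mlp_tensor_parallel():
    torch.manual_seed(0)
    x = torch.randn(4, 8)
    w1, w2 = torch.randn(16, 8), torch.randn(8, 16)
    full = torch.relu(x @ w1.T) @ w2.T
    partial = sum(
        torch.relu(x @ w1[s].T) @ w2[:, s].T for s in (slice(0, 8), slice(8, 16))
    )
    torch.testing.assert_close(full, partial)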
############################################
# TransformerLayer #
############################################
@run_distributed_test()
def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs):
    """Test the TransformerLayer with optional sequence parallelism."""
params_dtype = kwargs.get("params_dtype", torch.float32)
FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128
model_single_node = te.TransformerLayer(
HIDDEN_SIZE, FFN_HIDDEN_SIZE, NR_HEADS, attention_dropout=0, hidden_dropout=0, **kwargs
)
model_distributed = te.TransformerLayer(
HIDDEN_SIZE,
FFN_HIDDEN_SIZE,
NR_HEADS,
tp_size=WORLD_SIZE,
tp_group=NCCL_WORLD,
set_parallel_mode=True,
sequence_parallel=sequence_parallel,
seq_length=WORLD_SIZE * SEQ_LEN if sequence_parallel else None,
attention_dropout=0,
hidden_dropout=0,
**kwargs,
)
_copy_params(model_distributed, model_single_node)
_alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True
input_single_node = (
torch.randn((WORLD_SIZE * SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype)
)
if sequence_parallel:
input_distributed = input_single_node[
WORLD_RANK * SEQ_LEN : (WORLD_RANK + 1) * SEQ_LEN, :, :
]
else:
input_distributed = input_single_node.clone().cuda()
encoder_output = None
if "layer_type" in kwargs:
encoder_output = torch.randn((SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE)).cuda()
output_single_node, output_distributed = _apply_models(
model_single_node,
model_distributed,
input_single_node,
input_distributed,
encoder_output=encoder_output,
)
if sequence_parallel:
output_distributed = _gather(output_distributed)
_loss_backward(output_single_node, output_distributed)
_check_outputs(output_single_node, output_distributed)
# gradients in other cases need additional synchronization
if not sequence_parallel and "return_bias" not in kwargs:
_check_gradients(
model_distributed,
model_single_node,
main_grad_check=("fuse_wgrad_accumulation" in kwargs),
)
def test_transformer_layer():
    """Run TransformerLayer tests with various configurations."""
kwargs_list = [
{},
{"num_gqa_groups": 4},
{"init_method": _constant},
{"output_layer_init_method": _constant},
{"apply_residual_connection_post_layernorm": True},
{"output_layernorm": True},
{"parallel_attention_mlp": True},
# {"layer_type": "decoder"},
{"window_size": (2, 2)},
{"normalization": "RMSNorm"},
{"zero_centered_gamma": True},
{"fuse_qkv_params": True},
{"fuse_qkv_params": True, "fuse_wgrad_accumulation": True},
{"qkv_weight_interleaved": False},
{"bias": False},
{"params_dtype": torch.float16},
{"fuse_qkv_params": True},
{"activation": "relu"},
]
for kwargs in kwargs_list:
for sequence_parallel in [False, True]:
_test_transformer_layer_parallel(sequence_parallel, **kwargs)
if __name__ == "__main__":
sys.exit(main())
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
if torch.cuda.device_count() < 2:
pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
RNG_SEED: int = 42
SEQ_LENGTH: int = 1024
BATCH_SIZE: int = 2
NUM_HEADS: int = 16
HEAD_DIM: int = 48
TE_LAYERS = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
MAX_LAYER_NAME_LENGTH = max([len(layer.__name__) for layer in TE_LAYERS])
# to avoid numerical tolerance issues of doing comm gemm overlap, limit the number of GPUs used
MAX_GPUS_TO_USE = 4
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(torch.cuda.device_count(), MAX_GPUS_TO_USE)
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
if tex.ubuf_built_with_mpi():
LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python3"]
# Fall back on CUDA IPC if the platform does not support CUDA multicast
if not tex.device_supports_multicast():
os.environ["UB_SKIPMC"] = "1"
# Force GPU kernels to launch in the order they're executed by the host CPU
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
# Clear torch.dynamo caches
torch._dynamo.reset()
def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
test_path = TEST_ROOT / "run_gemm_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
"--check-numerics",
f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}",
f"--batch-size={BATCH_SIZE}",
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--comm-type={comm_type}",
]
if bulk:
test_cmd.append("--bulk-overlap")
else:
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
if p2p:
test_cmd.append("--p2p")
if atomic:
if torch.cuda.get_device_properties(0).major != 9:
pytest.skip("Atomic GEMM is requires device compute capability 9.x (Hopper).")
test_cmd.append("--atomic")
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
if (
result.returncode != 0
or "NUMERICAL CHECK FAILED" in result.stderr.decode()
or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
):
raise AssertionError(result.stderr.decode())
def _run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1
):
test_path = TEST_ROOT / "run_layer_with_overlap.py"
test_cmd = LAUNCH_CMD + [
str(test_path),
f"--seed={RNG_SEED}",
f"--seq-length={SEQ_LENGTH}",
f"--batch-size={BATCH_SIZE}",
f"--num-heads={NUM_HEADS}",
f"--head-dim={HEAD_DIM}",
f"--layer-type={layer_type}",
f"--num-layers={num_layers}",
]
if layer_type in [te.Linear.__name__, te.LayerNormLinear.__name__]:
test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}")
if overlap_rs_dgrad:
test_cmd.append("--overlap-rs-dgrad")
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
test_cmd.append("--fp8")
test_cmd.append(f"--quantization={quantization}")
os.environ["PYTORCH_JIT"] = "0"
os.environ["NVTE_TORCH_COMPILE"] = "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
    # os.unsetenv() does not update os.environ, so remove the keys directly
    os.environ.pop("PYTORCH_JIT", None)
    os.environ.pop("NVTE_TORCH_COMPILE", None)
    os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None)
if (
result.returncode != 0
or "NUMERICAL CHECK FAILED" in result.stderr.decode()
or "NUMERICAL CHECK PASSED" not in result.stdout.decode()
):
raise AssertionError(result.stderr.decode())
@pytest.mark.parametrize(
"fp8",
(False, True),
ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "],
)
def test_split_all_gather_overlaps(fp8):
"""
Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("AG", False, True, False, fp8)
@pytest.mark.parametrize(
"fp8,p2p",
[
(False, False),
(False, True),
(True, False),
(True, True),
],
ids=[
" BF16 - PIPELINE ",
" BF16 - RING-EXCHANGE ",
" FP8 - PIPELINE ",
" FP8 - RING-EXCHANGE ",
],
)
def test_split_reduce_scatter_overlaps(fp8, p2p):
"""
Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
te.cpp_extensions.fp8_gemm.
"""
_run_gemm_with_overlap("RS", False, p2p, False, fp8)
@pytest.mark.parametrize(
"comm_type, fp8, connections",
[
("AG", False, 1),
("RS", False, 1),
("RS", True, 1),
("AG", False, 8),
("RS", False, 8),
("RS", True, 8),
],
ids=[
"ALL-GATHER - BF16 - 1 connections",
"REDUCE-SCATTER - BF16 - 1 connections",
"REDUCE-SCATTER - FP8 - 1 connections",
"ALL-GATHER - BF16 - 8 connections",
"REDUCE-SCATTER - BF16 - 8 connections",
"REDUCE-SCATTER - FP8 - 8 connections",
],
)
def test_bulk_overlaps(comm_type, fp8, connections):
"""
Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
"""
if connections == 8:
if torch.cuda.get_device_properties(0).major != 9:
pytest.skip(
"CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability"
" 9.0 (HOPPER ARCH)."
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
_run_gemm_with_overlap(comm_type, True, False, False, fp8)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
else:
_run_gemm_with_overlap(comm_type, True, False, False, fp8)
@pytest.mark.parametrize(
"fp8",
(False,),
ids=[
" BF16 ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
[
(te.Linear.__name__, "row", False),
(te.Linear.__name__, "column", False),
(te.Linear.__name__, "column", True),
(te.LayerNormLinear.__name__, "row", False),
(te.LayerNormLinear.__name__, "column", False),
(te.LayerNormLinear.__name__, "column", True),
]
+ list(
zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
[None] * len(TE_LAYERS[2:]) * 2,
[False, True] * len(TE_LAYERS[2:]),
)
),
ids=[
f" {te.Linear.__name__} - ROW-PARALLEL ",
f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
]
+ [
" " + " - ".join(test_name_parts) + " "
for test_name_parts in zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
)
],
)
def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None)
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
[
(te.Linear.__name__, "row", False),
(te.Linear.__name__, "column", False),
(te.Linear.__name__, "column", True),
(te.LayerNormLinear.__name__, "row", False),
(te.LayerNormLinear.__name__, "column", False),
(te.LayerNormLinear.__name__, "column", True),
]
+ list(
zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
[None] * len(TE_LAYERS[2:]) * 2,
[False, True] * len(TE_LAYERS[2:]),
)
),
ids=[
f" {te.Linear.__name__} - ROW-PARALLEL ",
f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
]
+ [
" " + " - ".join(test_name_parts) + " "
for test_name_parts in zip(
[layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
)
],
)
def test_layers_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
@pytest.mark.parametrize(
"fp8",
(False,),
ids=[
" BF16 ",
],
)
@pytest.mark.parametrize(
"num_layers",
(2,),
ids=[
" 2 layers ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
list(
zip(
[te.TransformerLayer.__name__ for _ in range(2)],
[None] * 2,
[False, True],
)
),
ids=[
" " + " - ".join(test_name_parts) + " "
for test_name_parts in zip(
[te.TransformerLayer.__name__ for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"],
)
],
)
def test_multi_layer_with_overlap_bf16(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, num_layers
)
@pytest.mark.parametrize(
"quantization",
["fp8_delayed_scaling", "fp8_current_scaling"],
ids=[" DELAYED SCALING ", " CURRENT SCALING "],
)
@pytest.mark.parametrize(
"fp8",
(True,),
ids=[
" FP8 ",
],
)
@pytest.mark.parametrize(
"num_layers",
(2,),
ids=[
" 2 layers ",
],
)
@pytest.mark.parametrize(
"layer_type,linear_parallel_mode,overlap_rs_dgrad",
list(
zip(
[te.TransformerLayer.__name__ for _ in range(2)],
[None] * 2,
[False, True],
)
),
ids=[
" " + " - ".join(test_name_parts) + " "
for test_name_parts in zip(
[te.TransformerLayer.__name__ for _ in range(2)],
["BULK DGRAD/WGRAD", "DGRAD+RS"],
)
],
)
def test_multi_layer_with_overlap_fp8(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
):
"""
Test Transformer Engine layers with comm+GEMM overlap.
"""
_run_layer_with_overlap(
layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import argparse
from collections.abc import Iterable
import functools
import itertools
import os
import pathlib
import subprocess
import sys
from typing import Optional
import pytest
import torch
import transformer_engine
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.append("fp8")
if mxfp8_available:
quantization_list.append("mxfp8")
@functools.cache
def world_group() -> torch.distributed.ProcessGroup:
"""Get NCCL process group, initializing if needed"""
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(rank)
group = torch.distributed.init_process_group(
"nccl",
init_method="file:///tmp/rdzv",
world_size=world_size,
rank=rank,
)
return group
def reset_rng(seed: int = 1234) -> None:
"""Reset random number generators"""
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
The reference tensor is intended for use in plain PyTorch
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
"""
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
test = ref.to(device=test_device, dtype=test_dtype)
if test_is_fp8:
quantizer = Float8Quantizer(
scale=torch.ones(1, dtype=torch.float32, device=test_device),
amax=torch.zeros(1, dtype=torch.float32, device=test_device),
fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)
elif test.data_ptr() == ref.data_ptr():
test = test.clone()
ref.copy_(test)
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
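# Hedged sketch (hypothetical helper, never called) of the round-trip performed
# above: copying the (possibly lower-precision) test values back into the
# float64 reference makes both tensors represent exactly the same numbers
# before any computation runs.
def _demo_reference_round_trip():
    ref = torch.rand(4, dtype=torch.float64)
    test = ref.to(torch.float16)  # lossy downcast
    ref.copy_(test)  # re-align the reference with the rounded values
    assert torch.equal(ref, test.double())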
def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
if dtype == tex.DType.kFloat8E4M3:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == tex.DType.kFloat8E5M2:
return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
tex.DType.kFloat32: torch.float32,
tex.DType.kFloat16: torch.half,
tex.DType.kBFloat16: torch.bfloat16,
}[dtype]
# PyTorch dtypes
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float64:
return dict(rtol=1e-7, atol=1e-7)
raise ValueError(f"Unsupported dtype ({dtype})")
def make_recipe(name: Optional[str] = None) -> Optional[transformer_engine.common.recipe.Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
return None
if name == "fp8":
return transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
def _test_all_reduce(
*,
local_size: int = 17,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [local_size]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
# Plain PyTorch implementation
y_ref = x_ref.sum(0)
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dx_ref = x_ref.grad[rank]
x_ref = x_ref[rank]
x_test = x_test[rank].clone()
x_test.requires_grad_()
# Implementation with fusible operation
op = te_ops.AllReduce(process_group=process_group)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype))
torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0)
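# Hedged sketch (hypothetical helper, never called) of the autograd identity
# checked above: summing over the rank axis models an all-reduce, and its
# gradient broadcasts dy to every rank unchanged, which is why dx is compared
# with zero tolerance.
def _demo_all_reduce_grad():
    x = torch.randn(2, 5, requires_grad=True)  # two "ranks"
    dy = torch.randn(5)
    x.sum(0).backward(dy)
    assert torch.equal(x.grad[0], dy) and torch.equal(x.grad[1], dy)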
def _test_all_gather(
*,
local_size: int = 13,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, local_size]
out_shape = [world_size, world_size * local_size]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
# Plain PyTorch implementation
y_ref = x_ref.tile((world_size, 1)).reshape(out_shape)
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dx_ref = x_ref.grad[rank]
x_ref = x_ref[rank]
x_test = x_test[rank].clone()
y_ref = y_ref[rank]
dy_ref = dy_ref[rank]
dy_test = dy_test[rank].clone()
x_test.requires_grad_()
# Implementation with fusible operation
op = te_ops.AllGather(process_group=process_group)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0)
torch.testing.assert_close(dx_test, dx_ref, **dtype_tols(dtype))
def _test_reduce_scatter(
*,
local_size: int = 11,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
fp8: bool = False,
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
in_shape = [world_size, world_size * local_size]
out_shape = [world_size, local_size]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8,
)
# Plain PyTorch implementation
y_ref = x_ref.sum(0).reshape(out_shape)
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dx_ref = x_ref.grad[rank]
x_ref = x_ref[rank]
x_test = x_test[rank].clone()
y_ref = y_ref[rank]
dy_ref = dy_ref[rank]
dy_test = dy_test[rank].clone()
x_test.requires_grad_()
# Implementation with fusible operation
op = te_ops.ReduceScatter(process_group=process_group)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **dtype_tols(dtype))
torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0)
def _test_basic_linear(
*,
local_weight_shape: tuple[int, int] = (32, 32),
local_batch_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
quantization: Optional[str] = None,
quantized_weight: bool = False,
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
quantized_compute = quantization is not None
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
local_out_features, local_in_features = local_weight_shape
out_features, in_features = local_out_features, local_in_features
batch_size = local_batch_size
if tensor_parallel_mode == "column":
out_features *= world_size
elif tensor_parallel_mode == "row":
in_features *= world_size
if sequence_parallel:
batch_size *= world_size
in_shape = [batch_size, in_features]
out_shape = [batch_size, out_features]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dw_ref = w_ref.grad
dx_ref = x_ref.grad
if tensor_parallel_mode == "column":
local_out_features = out_features // world_size
local_slice = slice(
rank * local_out_features,
(rank + 1) * local_out_features,
)
w_ref = w_ref[local_slice, :]
dw_ref = dw_ref[local_slice, :]
w_test = w_test[local_slice, :]
y_ref = y_ref[..., local_slice]
dy_ref = dy_ref[..., local_slice]
dy_test = dy_test[..., local_slice].clone()
elif tensor_parallel_mode == "row":
local_in_features = in_features // world_size
local_slice = slice(
rank * local_in_features,
(rank + 1) * local_in_features,
)
w_ref = w_ref[:, local_slice]
dw_ref = dw_ref[:, local_slice]
w_test = w_test[:, local_slice]
x_ref = x_ref[..., local_slice]
dx_ref = dx_ref[..., local_slice]
x_test = x_test[..., local_slice].clone()
if sequence_parallel:
local_batch_size = batch_size // world_size
local_slice = slice(
rank * local_batch_size,
(rank + 1) * local_batch_size,
)
if tensor_parallel_mode == "column":
x_ref = x_ref[local_slice, ...]
dx_ref = dx_ref[local_slice, ...]
x_test = x_test[local_slice, ...].clone()
elif tensor_parallel_mode == "row":
y_ref = y_ref[local_slice, ...]
dy_ref = dy_ref[local_slice, ...]
dy_test = dy_test[local_slice, ...].clone()
x_test.requires_grad_()
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
op = te_ops.BasicLinear(
in_features,
out_features,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
torch.testing.assert_close(dw_test, dw_ref, **tols)
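# Hedged sketch (hypothetical helper, never called) of the column-parallel
# slicing used above: each rank owns a block of weight rows, so its local
# output is exactly the matching slice of the full output and the forward pass
# needs no reduction.
def _demo_column_parallel_slice():
    torch.manual_seed(0)
    x, w = torch.randn(4, 8), torch.randn(6, 8)
    full = x @ w.T
    torch.testing.assert_close(full[:, 3:6], x @ w[3:6].T)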
def _test_linear(
*,
bias: bool = True,
local_weight_shape: tuple[int, int] = (32, 32),
local_batch_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
quantization: Optional[str] = None,
quantized_weight: bool = False,
tensor_parallel_mode: str = "column",
sequence_parallel: bool = False,
) -> None:
quantized_compute = quantization is not None
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
local_out_features, local_in_features = local_weight_shape
out_features, in_features = local_out_features, local_in_features
batch_size = local_batch_size
if tensor_parallel_mode == "column":
out_features *= world_size
elif tensor_parallel_mode == "row":
in_features *= world_size
if sequence_parallel:
batch_size *= world_size
in_shape = [batch_size, in_features]
out_shape = [batch_size, out_features]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=(quantized_compute or quantized_weight),
)
if isinstance(w_test, QuantizedTensor):
w_test = w_test.dequantize()
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
bias_shape = [world_size, out_features]
else:
bias_shape = [out_features]
b_ref, b_test = make_reference_and_test_tensors(
bias_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=quantized_compute,
requires_grad=False,
)
if isinstance(dy_test, QuantizedTensor):
dy_test = dy_test.dequantize()
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
if bias:
if tensor_parallel_mode == "row":
y_ref += b_ref.sum(dim=0)
else:
y_ref += b_ref
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dw_ref = w_ref.grad
db_ref = b_ref.grad if bias else None
dx_ref = x_ref.grad
if tensor_parallel_mode == "column":
local_out_features = out_features // world_size
local_slice = slice(
rank * local_out_features,
(rank + 1) * local_out_features,
)
w_ref = w_ref[local_slice, :]
dw_ref = dw_ref[local_slice, :]
w_test = w_test[local_slice, :]
if bias:
b_ref = b_ref[local_slice]
db_ref = db_ref[local_slice]
b_test = b_test[local_slice]
y_ref = y_ref[..., local_slice]
dy_ref = dy_ref[..., local_slice]
dy_test = dy_test[..., local_slice].clone()
elif tensor_parallel_mode == "row":
local_in_features = in_features // world_size
local_slice = slice(
rank * local_in_features,
(rank + 1) * local_in_features,
)
w_ref = w_ref[:, local_slice]
dw_ref = dw_ref[:, local_slice]
w_test = w_test[:, local_slice]
if bias:
b_ref = b_ref[rank, :]
db_ref = db_ref[rank, :]
b_test = b_test[rank, :]
x_ref = x_ref[..., local_slice]
dx_ref = dx_ref[..., local_slice]
x_test = x_test[..., local_slice].clone()
if sequence_parallel:
local_batch_size = batch_size // world_size
local_slice = slice(
rank * local_batch_size,
(rank + 1) * local_batch_size,
)
if tensor_parallel_mode == "column":
x_ref = x_ref[local_slice, ...]
dx_ref = dx_ref[local_slice, ...]
x_test = x_test[local_slice, ...].clone()
elif tensor_parallel_mode == "row":
y_ref = y_ref[local_slice, ...]
dy_ref = dy_ref[local_slice, ...]
dy_test = dy_test[local_slice, ...].clone()
x_test.requires_grad_()
# Implementation with fusible operation
recipe = make_recipe(quantization)
with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
model = te_ops.Sequential(
te_ops.Linear(
in_features,
out_features,
bias=bias,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
),
)
with torch.no_grad():
model[0].weight.copy_(w_test)
if bias:
model[0].bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = model(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
torch.testing.assert_close(dw_test, dw_ref, **tols)
if bias:
db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, db_ref, **tols)
def _test_fp8_scale_update(
*,
amax_history_len: int = 31,
amax_compute_algo: str = "max",
margin: float = 2,
local_weight_shape: tuple[int, int] = (32, 32),
batch_size: int = 32,
dtype: torch.dtype = torch.float32,
device: torch.device = "cuda",
tensor_parallel_mode: str = "column",
) -> None:
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
local_out_features, local_in_features = local_weight_shape
out_features, in_features = local_out_features, local_in_features
if tensor_parallel_mode == "column":
out_features *= world_size
elif tensor_parallel_mode == "row":
in_features *= world_size
in_shape = [batch_size, in_features]
out_shape = [batch_size, out_features]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
def ref_amax_and_scale(
ref: torch.Tensor,
stage: str,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Expected absmax and FP8 scale"""
amax = ref.abs().amax()
max_val = {
"forward": 448.0,
"backward": 57344.0,
}[stage]
scale = (max_val / amax) / (2**margin)
amax = amax.to(dtype=torch.float32, device="cpu")
scale = scale.to(dtype=torch.float32, device="cpu")
return amax, scale
# Compute expected amaxes and FP8 scales
x_amax_ref, x_scale_ref = ref_amax_and_scale(x_ref, "forward")
w_amax_ref, w_scale_ref = ref_amax_and_scale(w_ref, "forward")
dy_amax_ref, dy_scale_ref = ref_amax_and_scale(dy_ref, "backward")
# Convert to distributed tensors
with torch.no_grad():
if tensor_parallel_mode == "column":
local_out_features = out_features // world_size
local_slice = slice(
rank * local_out_features,
(rank + 1) * local_out_features,
)
w_ref = w_ref[local_slice, :]
w_test = w_test[local_slice, :]
dy_ref = dy_ref[..., local_slice]
dy_test = dy_test[..., local_slice].clone()
elif tensor_parallel_mode == "row":
local_in_features = in_features // world_size
local_slice = slice(
rank * local_in_features,
(rank + 1) * local_in_features,
)
w_ref = w_ref[:, local_slice]
w_test = w_test[:, local_slice]
x_ref = x_ref[..., local_slice]
x_test = x_test[..., local_slice].clone()
x_test.requires_grad_()
# Initialize fusible operation
op = te_ops.BasicLinear(
in_features,
out_features,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=process_group,
)
with torch.no_grad():
op.weight.copy_(w_test)
del w_test
# Forward and backward pass
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
with te.fp8_autocast(fp8_recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
x_quantizer = op.get_quantizer("forward", 0)
w_quantizer = op.get_quantizer("forward", 1)
dy_quantizer = op.get_quantizer("backward", 0)
x_scale_test = x_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([])
w_scale_test = w_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([])
dy_scale_test = dy_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([])
torch.testing.assert_close(x_scale_test, x_scale_ref)
torch.testing.assert_close(w_scale_test, w_scale_ref)
torch.testing.assert_close(dy_scale_test, dy_scale_ref)
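# Hedged sketch (hypothetical helper, never called) of the delayed-scaling
# update verified above: the scale maps the observed amax to the format's
# representable maximum, backed off by 2**margin (HYBRID recipe: E4M3 forward
# max 448, E5M2 backward max 57344).
def _demo_delayed_scaling_scale():
    amax, margin = torch.tensor(3.5), 2
    scale = (448.0 / amax) / (2**margin)  # forward / E4M3 case
    assert torch.isclose(scale * amax, torch.tensor(448.0 / 4))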
def run_parallel_tests() -> None:
"""Run parallel tests"""
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Collective communication ops
if rank == 0:
print(f"Running _test_all_reduce")
_test_all_reduce()
if rank == 0:
print(f"Running _test_all_gather")
_test_all_gather()
if rank == 0:
print(f"Running _test_reduce_scatter")
_test_reduce_scatter()
# Basic linear op
for config in itertools.product(
quantization_list,
("column", "row"),
(False, True),
):
if rank == 0:
print(f"Running _test_basic_linear with {config=}")
quantization, tensor_parallel_mode, sequence_parallel = config
_test_basic_linear(
quantization=quantization,
tensor_parallel_mode=tensor_parallel_mode,
sequence_parallel=sequence_parallel,
)
# Linear op
for config in itertools.product(
quantization_list,
("column", "row"),
):
if rank == 0:
print(f"Running _test_linear with {config=}")
quantization, tensor_parallel_mode = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
_test_linear(
bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype,
quantization=quantization,
tensor_parallel_mode=tensor_parallel_mode,
)
# FP8 scale update
if fp8_available:
if rank == 0:
print(f"Running _test_fp8_scale_update")
_test_fp8_scale_update()
# Parallel job sizes
_world_sizes = [torch.cuda.device_count()]
if 1 not in _world_sizes:
_world_sizes.append(1)
if torch.cuda.device_count() >= 2 and 2 not in _world_sizes:
_world_sizes.append(2)
@pytest.mark.parametrize("world_size", _world_sizes)
def test_distributed_fuser_ops(world_size: int) -> None:
"""Launch parallel job that runs parallel tests"""
python_exe = pathlib.Path(sys.executable).resolve()
current_file = pathlib.Path(__file__).resolve()
command = [
python_exe,
"-m",
"torch.distributed.run",
f"--nproc_per_node={world_size}",
current_file,
"--parallel",
]
    subprocess.run(command, check=True)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
args = parser.parse_args()
if args.parallel:
run_parallel_tests()
if __name__ == "__main__":
main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import argparse
from collections.abc import Iterable
import dataclasses
import functools
import itertools
import os
import pathlib
import subprocess
import sys
import pytest
import torch
import transformer_engine
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops._common import is_float8_tensor
from transformer_engine.pytorch.ops.fused import (
UserbuffersBackwardLinear,
UserbuffersForwardLinear,
)
from transformer_engine.pytorch.utils import is_bf16_compatible
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, str_to_dtype
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
# Check if there are multiple GPUs
if torch.cuda.device_count() < 2:
pytest.skip("Userbuffers requires at least 2 GPUs.")
@dataclasses.dataclass
class ModelConfig:
"""Tensor dimensions in Transformer model"""
sequence_length: int
batch_size: int
num_heads: int
head_dim: int
dtype: torch.dtype
fp8: bool
@property
def hidden_size(self):
return self.num_heads * self.head_dim
@functools.cache
def launcher() -> str:
"""Launcher for current parallel job"""
if "OMPI_COMM_WORLD_SIZE" in os.environ:
return "ompi"
if "TORCHELASTIC_RUN_ID" in os.environ:
return "torchrun"
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`")
@functools.cache
def world_group() -> torch.distributed.ProcessGroup:
"""Get NCCL process group, initializing if needed"""
# Get launch config from environment
if launcher() == "ompi":
# OpenMPI
world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE"))
rank = int(os.getenv("OMPI_COMM_WORLD_RANK"))
local_size = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE"))
local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK"))
elif launcher() == "torchrun":
# torchrun
world_size = int(os.getenv("WORLD_SIZE"))
rank = int(os.getenv("RANK"))
local_size = int(os.getenv("LOCAL_WORLD_SIZE"))
local_rank = int(os.getenv("LOCAL_RANK"))
else:
raise RuntimeError("Unexpected launcher ({launcher()})")
# Construct communicator
assert local_size == world_size
torch.cuda.set_device(local_rank)
group = torch.distributed.init_process_group(
"nccl",
init_method="file:///tmp/rdzv",
world_size=world_size,
rank=rank,
device_id=torch.device(f"cuda:{local_rank}"),
)
return group
def reset_rng(seed: int = 1234) -> None:
"""Reset random number generators"""
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@torch.no_grad()
def make_reference_and_test_tensors(
shape: int | Iterable[int],
ref_dtype: torch.dtype = torch.float64,
ref_device: torch.device = "cpu",
test_dtype: torch.dtype = torch.float32,
test_device: torch.device = "cuda",
test_is_fp8: bool = False,
requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Construct tensors with the same values
The reference tensor is intended for use in plain PyTorch
operations in high precision. The test tensor is intended for use
in Transformer Engine operations.
"""
# Random data
ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
# Make copy of tensor
if test_is_fp8:
test = Float8Tensor.to_float8(ref)
else:
test = ref.to(device=test_device, dtype=test_dtype)
if test.data_ptr() == ref.data_ptr():
test = test.clone()
# Make sure reference and test tensors represent exact same values
ref.copy_(test)
# Return reference and test tensors
ref.requires_grad_(requires_grad)
test.requires_grad_(requires_grad)
return ref, test
def _test_linear(
*,
model_config: ModelConfig,
bias: bool = False,
device: torch.device = "cuda",
tensor_parallel_mode: str = "column",
sequence_parallel: bool = True,
weight_requires_grad: bool = True,
) -> None:
dtype = model_config.dtype
fp8_compute = model_config.fp8
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Tensor dimensions
out_features = model_config.hidden_size
in_features = model_config.hidden_size
batch_size = model_config.sequence_length * model_config.batch_size
in_shape = [batch_size, in_features]
out_shape = [batch_size, out_features]
# Random data
reset_rng()
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
)
b_ref, b_test = None, None
if bias:
if tensor_parallel_mode == "row":
bias_shape = [world_size, out_features]
else:
bias_shape = [out_features]
b_ref, b_test = make_reference_and_test_tensors(
bias_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_compute,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = torch.nn.functional.linear(x_ref, w_ref)
if bias:
if tensor_parallel_mode == "row":
y_ref += b_ref.sum(dim=0)
else:
y_ref += b_ref
y_ref.backward(dy_ref)
# Convert to distributed tensors
with torch.no_grad():
dw_ref = w_ref.grad
db_ref = b_ref.grad if bias else None
dx_ref = x_ref.grad
if tensor_parallel_mode == "column":
local_out_features = out_features // world_size
local_slice = slice(
rank * local_out_features,
(rank + 1) * local_out_features,
)
w_ref = w_ref[local_slice, :]
dw_ref = dw_ref[local_slice, :]
w_test = w_test[local_slice, :]
if bias:
b_ref = b_ref[local_slice]
db_ref = db_ref[local_slice]
b_test = b_test[local_slice]
y_ref = y_ref[..., local_slice]
dy_ref = dy_ref[..., local_slice]
dy_test = dy_test[..., local_slice].clone()
elif tensor_parallel_mode == "row":
local_in_features = in_features // world_size
local_slice = slice(
rank * local_in_features,
(rank + 1) * local_in_features,
)
w_ref = w_ref[:, local_slice]
dw_ref = dw_ref[:, local_slice]
w_test = w_test[:, local_slice]
if bias:
b_ref = b_ref[rank, :]
db_ref = db_ref[rank, :]
b_test = b_test[rank, :]
x_ref = x_ref[..., local_slice]
dx_ref = dx_ref[..., local_slice]
x_test = x_test[..., local_slice].clone()
if sequence_parallel:
local_batch_size = batch_size // world_size
local_slice = slice(
rank * local_batch_size,
(rank + 1) * local_batch_size,
)
if tensor_parallel_mode == "column":
x_ref = x_ref[local_slice, ...]
dx_ref = dx_ref[local_slice, ...]
x_test = x_test[local_slice, ...].clone()
elif tensor_parallel_mode == "row":
y_ref = y_ref[local_slice, ...]
dy_ref = dy_ref[local_slice, ...]
dy_test = dy_test[local_slice, ...].clone()
x_test.requires_grad_()
# Implementation with fusible operation
with te.fp8_model_init(enabled=fp8_compute):
ops = []
linear_op = None
bias_op = None
if tensor_parallel_mode == "column":
userbuffers_options = {}
if not weight_requires_grad:
if fp8_compute:
userbuffers_options["comm_name"] = "fc1"
else:
# There is a correctness bug with overlapping
# dgrad reduce-scatter with dgrad GEMM. Fall back
# to overlapping dgrad reduce-scatter with wgrad
# GEMM, even though wgrad isn't needed.
userbuffers_options["comm_name"] = "qkv"
else:
userbuffers_options["comm_name"] = "qkv"
linear_op = te_ops.BasicLinear(
in_features,
out_features,
device=device,
dtype=dtype,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_group=process_group,
sequence_parallel=sequence_parallel,
userbuffers_options=userbuffers_options,
)
ops.append(linear_op)
if bias:
bias_op = te_ops.Bias(
out_features // world_size,
device=device,
dtype=dtype,
)
ops.append(bias_op)
elif tensor_parallel_mode == "row":
userbuffers_options = dict(comm_name="proj")
linear_op = te_ops.BasicLinear(
in_features // world_size,
out_features,
device=device,
dtype=dtype,
userbuffers_options=userbuffers_options,
)
ops.append(linear_op)
if bias:
bias_op = te_ops.Bias(out_features, device=device, dtype=dtype)
ops.append(bias_op)
ops.append(te_ops.ReduceScatter(process_group))
model = te_ops.Sequential(*ops)
with torch.no_grad():
linear_op.weight.copy_(w_test)
linear_op.weight.requires_grad_(requires_grad=weight_requires_grad)
if bias:
bias_op.bias.copy_(b_test)
del w_test
del b_test
with te.fp8_autocast(enabled=fp8_compute):
y_test = model(x_test)
y_test.backward(dy_test)
# Check that forward operations have been fused
forward_ops = model._module_groups[0]._forward_ops
backward_ops = model._module_groups[0]._backward_ops
assert len(forward_ops) == 1
assert len(backward_ops) == 1
assert isinstance(forward_ops[0][0], UserbuffersForwardLinear)
assert isinstance(backward_ops[0][0], UserbuffersBackwardLinear)
# Expected numerical error
tols = dtype_tols(dtype)
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if fp8_compute:
tols = dtype_tols(
model[0].weight._fp8_dtype
if is_float8_tensor(model[0].weight)
else tex.DType.kFloat8E4M3
)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)
if weight_requires_grad:
dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dw_test, dw_ref, **tols)
if bias:
db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, db_ref, **tols)
def run_parallel_tests(model_config: ModelConfig) -> None:
"""Run parallel tests"""
# Distributed process group
process_group = world_group()
rank = torch.distributed.get_rank(process_group)
world_size = torch.distributed.get_world_size(process_group)
# Linear op
for test_config in itertools.product(
(False, True), # bias
("column", "row"), # tensor_parallel_mode
(False, True), # weight_requires_grad
):
if rank == 0:
print(f"Running _test_linear with {test_config=}")
bias, tensor_parallel_mode, weight_requires_grad = test_config
_test_linear(
model_config=model_config,
bias=bias,
tensor_parallel_mode=tensor_parallel_mode,
weight_requires_grad=weight_requires_grad,
)
# Parallel job sizes
_world_sizes = []
if torch.cuda.device_count() > 1:
_world_sizes.append(torch.cuda.device_count())
@pytest.mark.parametrize("world_size", _world_sizes)
@pytest.mark.parametrize("fp8", (False, True))
def test_fuser_ops_with_userbuffers(
*,
world_size: int,
dtype: torch.dtype = torch.bfloat16,
fp8: bool,
) -> None:
"""Launch parallel job and run tests"""
# Skip invalid configurations
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
# Parallel job launcher
command = []
if tex.ubuf_built_with_mpi():
python_exe = pathlib.Path(sys.executable).resolve()
command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe))
else:
command.extend(("torchrun", f"--nproc_per_node={world_size}"))
# Script invocation
command.extend(
(
_current_file,
"--parallel",
"--batch-size",
str(world_size),
"--num-heads",
str(world_size),
"--dtype",
str(dtype),
)
)
if fp8:
command.append("--fp8")
# Environment
env = dict(os.environ)
if not tex.device_supports_multicast():
env["UB_SKIPMC"] = "1"
env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
env["PYTORCH_JIT"] = "0"
env["NVTE_TORCH_COMPILE"] = "0"
env["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
# Launch parallel job
subprocess.run(command, check=True, env=env)
def main() -> None:
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
parser.add_argument("--sequence-length", type=int, default=32)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--num-heads", type=int, default=16)
parser.add_argument("--head-dim", type=int, default=32)
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--fp8", action="store_true")
args = parser.parse_args()
# Run parallel tests if needed
if args.parallel:
# Model config
model_config = ModelConfig(
sequence_length=args.sequence_length,
batch_size=args.batch_size,
num_heads=args.num_heads,
head_dim=args.head_dim,
dtype=str_to_dtype(args.dtype),
fp8=args.fp8,
)
# Initialize Userbuffers
group = world_group() # Initialize NCCL
bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
userbuffer_configs = {
"fc1_dgrad": {"method": "pipeline"}, # Overlap dgrad RS with dgrad GEMM
}
te.module.base.initialize_ub(
[
model_config.sequence_length * model_config.batch_size,
model_config.num_heads * model_config.head_dim,
],
torch.distributed.get_world_size(group),
use_fp8=model_config.fp8,
dtype=model_config.dtype,
bootstrap_backend=bootstrap_backend,
ub_cfgs=userbuffer_configs,
)
# Run tests
run_parallel_tests(model_config)
# Clean up
te.module.base.destroy_ub()
if __name__ == "__main__":
main()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
"""
Distributed numerics tests

These tests check the numerical correctness of the TransformerEngine layers.
Tests are parametrized by the layer and FP8 precision.
Each test runs multiple configurations from the file run_numerics.py.
This design amortizes the long per-test initialization: two processes
need to start and load torch and TE, so running multiple configurations
in one test reduces the initialization overhead.
"""
if torch.cuda.device_count() < 2:
pytest.skip("Distributed training needs at least 2 GPUs.")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
def _run_test(quantization):
test_path = TEST_ROOT / "run_numerics.py"
test_cmd = LAUNCH_CMD + [str(test_path)]
if quantization is not None:
test_cmd += ["--quantization", quantization]
result = subprocess.run(test_cmd, env=os.environ, check=False)
assert result.returncode == 0
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"])
def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "fp8_cs" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
_run_test(quantization)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import torch
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
NUM_PROCS: int = torch.cuda.device_count()
def _run_test(fp8_init, sharding_dims):
test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]
if fp8_init:
test_cmd += ["--fp8-init"]
if len(sharding_dims) == 1:
test_cmd += ["--sharding-dims", str(sharding_dims[0])]
elif len(sharding_dims) == 2:
test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
else:
assert False
subprocess.run(test_cmd, env=os.environ, check=True)
@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
@pytest.mark.skipif(torch_version() < (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize("fp8_init", (False, True))
def test_distributed(fp8_init, sharding_dims):
# Skip invalid configurations
if torch.cuda.device_count() < 4:
pytest.skip("FSDP2 test requires at least 4 GPUs")
if fp8_init and not fp8_available:
pytest.skip(reason_for_no_fp8)
_run_test(fp8_init, sharding_dims)
def test_dummy() -> None:
"""Dummy test
pytest returns exit code 5 if all tests are skipped.
"""
pass
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os, sys, logging
from contextlib import nullcontext
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from transformer_engine.common.recipe import DelayedScaling
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
def run_dpa_with_cp(
dtype="bf16",
model=None,
qkv_format="bshd",
kernel_backend="FlashAttention",
cp_comm_type="p2p",
fp8_mha=False,
):
"""Test DotProductAttention module with context parallelism"""
# args are passed as strings
fp8_mha = fp8_mha == "True"
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
assert config.attn_mask_type in [
"causal",
"no_mask",
], f"{config.attn_mask_type} is an unsupported attention mask type!"
if qkv_format == "thd":
if "causal" in config.attn_mask_type:
config.attn_mask_type = "padding_causal"
else:
config.attn_mask_type = "padding"
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)
print(f"[INFO] world_size:{world_size}, rank:{rank}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert (
world_size % 2 == 0
), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
# instantiate core attn module
core_attn = DotProductAttention(
config.num_heads,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
window_size=config.window_size,
)
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size,
config.max_seqlen_q,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "sbhd":
q_input_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.max_seqlen_kv,
config.batch_size,
config.num_gqa_groups,
config.head_dim_qk,
)
attn_output_shape = (
config.max_seqlen_q,
config.batch_size,
config.num_heads * config.head_dim_qk,
)
cu_seqlens_q = None
cu_seqlens_kv = None
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
elif qkv_format == "thd":
q_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads,
config.head_dim_qk,
)
kv_input_shape = (
config.batch_size * config.max_seqlen_q,
config.num_gqa_groups,
config.head_dim_qk,
)
attn_output_shape = (
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim_qk,
)
seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
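# Round every sequence length up to a multiple of 2 * world_size so that each
# sequence can be split into 2 * world_size equal chunks for CP load balancing.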
seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
cu_seqlens_q_padded = torch.cat(
[
torch.zeros([1], dtype=torch.int32),
seqlens_q_padded.cumsum(0, dtype=torch.int32),
torch.tensor([q_input_shape[0]], dtype=torch.int32),
]
).cuda()
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
if kernel_backend == "FusedAttention":
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
cu_seqlens_q[-1] = cu_seqlens_q[-2]
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
# create flash attention bias
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
else:
bias = None
# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
out = core_attn(
q,
k,
v,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
if fp8_mha:
dout_fp8 = dout_quantizer(dout)
out.backward(dout_fp8)
else:
out.backward(dout)
# run core_attn with CP
q_, k_, v_, dout_, *rest = [
x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
]
bias_ = rest[0] if len(rest) else None
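# Shard inputs for CP: split the sequence dim into 2 * world_size chunks;
# rank r takes chunks r and (2 * world_size - 1 - r), which balances the
# causal-attention workload across ranks.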
if qkv_format == "bshd" or qkv_format == "sbhd":
seq_dim = qkv_format.index("s")
q_, k_, v_, dout_ = [
x.view(
*x.shape[:seq_dim],
2 * world_size,
x.shape[seq_dim] // (2 * world_size),
*x.shape[(seq_dim + 1) :],
)
for x in [q_, k_, v_, dout_]
]
seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [
x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
]
elif qkv_format == "thd":
seq_idx_q = tex.thd_get_partitioned_indices(
cu_seqlens_q_padded, q_.shape[0], world_size, rank
)
seq_idx_kv = tex.thd_get_partitioned_indices(
cu_seqlens_kv_padded, k_.shape[0], world_size, rank
)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(
*bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1]
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
if dtype == "fp8":
core_attn.reset_fp8_meta_tensors()
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
with fp8_context:
out_ = core_attn(
q_,
k_,
v_,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
)
if fp8_mha:
dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
for x in [out_, q_.grad, k_.grad, v_.grad]:
assert torch.all(~torch.isnan(x))
assert torch.all(~torch.isinf(x))
# compare results with and without CP
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [
x.view(
*x.shape[:seq_dim],
2 * world_size,
x.shape[seq_dim] // (2 * world_size),
*x.shape[(seq_dim + 1) :],
)
for x in [q.grad, k.grad, v.grad, out]
]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [
x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
for x in [q_.grad, k_.grad, v_.grad, out_]
]
elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
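# Each rank holds 1/world_size of every padded sequence, so scale the
# cumulative sequence lengths down accordingly before checking the padding.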
cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
)
cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q
num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1]
for x in [dq, out, dq_, out_]:
assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_q[b] == 0
or torch.count_nonzero(
x[(cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[b + 1]]
).item()
== 0
)
cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
)
cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv
num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1]
for x in [dk, dv, dk_, dv_]:
assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0
for b in range(config.batch_size):
assert (
num_pads_kv[b] == 0
or torch.count_nonzero(
x[
(cu_seqlens_kv_padded[b + 1] - num_pads_kv[b]) : cu_seqlens_kv_padded[
b + 1
]
]
).item()
== 0
)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
elif dtype == "fp16":
tols = dict(atol=5e-3, rtol=5e-3)
elif dtype == "fp8":
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
else:
assert False, f"{dtype} is an unsupported dtype!"
def _rmse(a, b):
return torch.sqrt((a - b).square().mean()).item()
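# For FP8, exact elementwise tolerances can be too strict; fall back to an
# RMSE check normalized by the combined value range of both tensors.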
def _error(a, b):
if dtype != "fp8":
torch.testing.assert_close(a, b, **tols)
else:
try:
torch.testing.assert_close(a, b, **tols)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert (
rmse < rmse_tol * rmse_range
), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
if qkv_format == "bshd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[:, 0], b[:, 0])
_error(a[:, 1], b[:, 1])
elif qkv_format == "sbhd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a[0], b[0])
_error(a[1], b[1])
elif qkv_format == "thd":
for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
_error(a, b)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
dist.destroy_process_group()
def main(**kwargs):
run_dpa_with_cp(**kwargs)
if __name__ == "__main__":
kwargs = dict(arg.split("=") for arg in sys.argv[2:])
main(**kwargs)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import functools
import logging
import math
import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union, Optional
from contextlib import contextmanager
import pytest
import torch
from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
from transformer_engine.pytorch.attention import (
DotProductAttention,
MultiheadAttention,
_attention_backends,
)
from transformer_engine.pytorch.dot_product_attention.utils import (
FlashAttentionUtils,
get_attention_backend,
check_set_window_size,
AttentionParams,
)
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
AttnBiasType,
AttnMaskType,
FusedAttnBackend,
QKVLayout,
fused_attn_bwd,
fused_attn_fwd,
)
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from transformer_engine.pytorch.utils import get_cudnn_version
import transformer_engine_torch as tex
from transformer_engine_torch import NVTE_Fused_Attn_Backend
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
# Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
fp8.FP8GlobalStateManager.reset()
class ModelConfig:
def __init__(
self,
batch_size: int,
num_heads: int,
num_gqa_groups: int,
head_dim_qk: int,
max_seqlen_q: int,
max_seqlen_kv: int,
dropout_p: float,
attn_mask_type: str,
attn_bias_type: str,
head_dim_v: int = None,
alibi_type: str = "none",
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
):
self.batch_size = batch_size
self.num_heads = num_heads
self.num_gqa_groups = num_gqa_groups
self.head_dim_qk = head_dim_qk
self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v
self.hidden_size = num_heads * head_dim_qk
self.hidden_size_kv = num_gqa_groups * self.head_dim_v
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
@contextmanager
def logging_context(highest_level=logging.WARNING):
"""Temporarily disable all log messages at or below the given severity."""
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
def _get_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
alibi_slopes_shape = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes_shape = [config.num_heads]
if config.bias_shape == "bhss":
alibi_slopes_shape = [config.batch_size, config.num_heads]
core_attention_bias_shape = (
config.bias_shape if config.attn_bias_type == "post_scale_bias" else None
)
core_attention_bias_requires_grad = False
# d=256 is supported by cuDNN 9.0+ for inference but not training
if (
config.attn_bias_type == "post_scale_bias"
and config.head_dim_qk <= 128
and config.head_dim_v <= 128
):
core_attention_bias_requires_grad = True
fused_attn_backends = []
available_backends = None
fused_attention_backend = None
def test():
attention_params = AttentionParams(
qkv_dtype=qkv_dtype,
qkv_layout=qkv_layout,
batch_size=config.batch_size,
num_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
context_parallel=context_parallel,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
)
(
use_flash_attention,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
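# Force each fused-attention backend in turn via NVTE_FUSED_ATTN_BACKEND
# and record which ones the backend selection logic actually returns.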
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
}
param_types = [torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
def test_dot_product_attention(
dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
):
"""Test DotProductAttention module"""
# Get configs
tols = dict(atol=1e-3, rtol=1e-3)
if dtype == torch.bfloat16:
tols = dict(atol=1.5e-2, rtol=1.5e-2)
config = model_configs[model]
is_mla = config.head_dim_qk != config.head_dim_v
is_mqa_gqa = config.num_heads != config.num_gqa_groups
if qkv_layout is None:
if config.attn_type == "self":
qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
else:
qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip("No need to test this layout for cross attention")
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
available_backends, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
if (
pad_between_seqs
and FlashAttentionUtils.is_installed
and not (
config.max_seqlen_q != config.max_seqlen_kv
and config.attn_mask_type in ["causal", "padding_causal"]
)
and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus)
):
flash_attn_supported = True
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"UnfusedDotProductAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backends) == 2:
logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_mla = {
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0": ModelConfig(
8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128
), # self , 0
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_mla])
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
"""Test DotProductAttention module with Multi-Latent Attention (MLA)"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_5_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
"mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
"mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"),
"mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"),
"mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_10_0": ModelConfig(
2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_10_1": ModelConfig(
2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
"bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
"bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
"bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
"bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
"bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
"bias_2_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"
), # skipped
"bias_2_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"
), # skipped
"bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
"bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
"bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
"bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"bias_3_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"
), # skipped
"bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
"bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
"bias_4_0": ModelConfig(
4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"
), # skipped
"bias_4_1": ModelConfig(
2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"
), # skipped
"bias_4_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"
), # skipped
"bias_4_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"
), # skipped
"bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
"bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
"bias_1_0": ModelConfig(
4,
16,
16,
64,
128,
128,
0.0,
# mask, bias, bias_shape,
"no_mask",
"post_scale_bias",
bias_shape="11ss",
),
"bias_1_1": ModelConfig(
2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss"
),
"bias_1_2": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss"
),
"bias_1_3": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss"
),
"bias_1_4": ModelConfig(
4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom"
),
"bias_1_5": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom"
),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types and shapes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_1": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_6_2": ModelConfig(
2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_6_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
}
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
"alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
"alibi_2_0": ModelConfig(
2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom"
),
"alibi_2_1": ModelConfig(
1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom"
),
}
@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
qkv_layouts = [
"sb3hd",
"sbh3d",
"sbhd_sb2hd",
"sbhd_sbh2d",
"sbhd_sbhd_sbhd",
"bs3hd",
"bsh3d",
"bshd_bs2hd",
"bshd_bsh2d",
"bshd_bshd_bshd",
]
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
"layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
"layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
"layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"layout_2_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_2_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_2_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_3_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_3_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_3_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_4_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
),
"layout_4_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
),
"layout_4_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
),
"layout_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
),
"layout_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
),
"layout_5_2": ModelConfig(
2,
24,
24,
128,
2048,
4096,
0.0,
"padding_causal_bottom_right",
"no_bias",
window_size=(4, 0),
),
}
@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@pytest.mark.skipif(
get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
config = model_configs[model]
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
pad_between_seqs = True
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
if get_cudnn_version() >= (9, 3, 0):
logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
# cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
pad_between_seqs = False
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
def _run_dot_product_attention(
dtype: torch.dtype,
config: ModelConfig,
backend: str,
ckpt_attn: bool,
qkv_layout: str,
workspace_opt: bool,
pad_between_seqs: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
# Set RNG and environment variables
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
_attention_backends["backend_selection_requires_update"] = True
# Create seqlens
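# e.g. "sb3hd" -> "sbhd": strip packing digits from the first layout chunk
# to recover the base qkv_format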
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
if "padding" in config.attn_mask_type or qkv_format == "thd":
if config.attn_type == "self":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = seqlens_q
if config.attn_type == "cross":
if config.max_seqlen_q > 1:
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
seqlens_q_after_pad = seqlens_q.clone()
seqlens_kv_after_pad = seqlens_kv.clone()
cu_seqlens_q_after_pad = cu_seqlens_q.clone()
cu_seqlens_kv_after_pad = cu_seqlens_kv.clone()
pad_len = [0] * config.batch_size
if pad_between_seqs:
max_pad_len = 3
pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")
seqlens_q_after_pad = seqlens_q + pad_len
seqlens_kv_after_pad = seqlens_kv + pad_len
cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
cu_seqlens_kv_after_pad[1:] = torch.cumsum(seqlens_kv_after_pad, dim=0)
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
if config.attn_type == "self":
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = attention_mask_q.to(device="cuda")
if config.attn_type == "cross":
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask_kv = torch.cat(
[
attention_mask_kv,
torch.Tensor(
[False] * seqlens_kv[i]
+ [True] * (config.max_seqlen_kv - seqlens_kv[i])
)
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = (
attention_mask_q.to(device="cuda"),
attention_mask_kv.to(device="cuda"),
)
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
alibi_slopes = (
torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
)
if config.bias_shape == "bhss":
alibi_slopes = (
torch.randn(config.batch_size, config.num_heads)
.abs()
.to(dtype=torch.float32, device="cuda")
)
# Create input tensors
dim_to_num = {
"b": config.batch_size,
"sq": config.max_seqlen_q,
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"dqk": config.head_dim_qk,
"dv": config.head_dim_v,
"t": cu_seqlens_q_after_pad[-1],
"tg": cu_seqlens_kv_after_pad[-1],
"3": 3,
"2": 2,
"1": 1,
}
inp = []
inp_orig = []
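# Expand each layout chunk (e.g. "sb3hd" -> "s_b_3_h_d") and remap generic
# dims to query/key-value-specific keys so shapes can be read from dim_to_num.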
for i, layout in enumerate(qkv_layout.split("_")):
layout = "_".join(layout)
if i == 0:
layout = layout.replace("s", "sq")
else:
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
if i == 2:
layout = layout.replace("d", "dv")
else:
layout = layout.replace("d", "dqk")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_orig = tensor
if qkv_format == "thd" and pad_between_seqs:
tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if layout in ["t_h_dqk", "t_3_h_dqk", "t_h_3_dqk"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
cu_seqlens_q_after_pad[i] - pad_len[i - 1],
)
pad_range = (
cu_seqlens_q_after_pad[i] - pad_len[i - 1],
cu_seqlens_q_after_pad[i],
)
tensor[pad_range[0] : pad_range[1]] = 0.0
tensor_orig = torch.cat(
[tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
)
if layout in ["tg_hg_dqk", "tg_2_hg_dqk", "tg_hg_2_dqk", "tg_hg_dv"]:
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_kv_after_pad[i - 1],
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
)
pad_range = (
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
cu_seqlens_kv_after_pad[i],
)
tensor[pad_range[0] : pad_range[1]] = 0.0
tensor_orig = torch.cat(
[tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
)
tensor_count = 1
split_dim = 0
for dim, l in enumerate(layout.split("_")):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
break
tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
tensors_orig = (
torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
)
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
inp_orig.append(tensors_orig[j].squeeze(split_dim))
else:
inp.append(tensors[j])
inp_orig.append(tensors_orig[j])
for i in range(3):
inp[i].requires_grad = True
inp_orig[i].requires_grad = True
# Create output gradient
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
qkv_format_kv = qkv_format_kv.replace("d", "dv")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
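# Flatten (heads, head_dim_v) into a single hidden dim for the output gradient.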
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad
if qkv_format == "thd" and pad_between_seqs:
out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if qkv_format_kv == "t_h_dv":
for i in range(1, config.batch_size + 1):
valid_range = (
cu_seqlens_q_after_pad[i - 1],
cu_seqlens_q_after_pad[i] - pad_len[i - 1],
)
pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
out_grad[pad_range[0] : pad_range[1]] = 0.0
out_grad_orig = torch.cat(
[out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
)
# Create bias
if config.attn_bias_type in ["no_bias", "alibi"]:
bias = None
if config.attn_bias_type == "post_scale_bias":
shape = "_".join(config.bias_shape)
shape = shape.replace("_s_s", "_sq_skv")
tensor_shape = [dim_to_num[j] for j in shape.split("_")]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != "1hss":
bias.requires_grad = False
# Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
# Set up model
block = DotProductAttention(
config.num_heads,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type=config.attn_type,
).to(dtype=dtype, device="cuda")
# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
q = inp_orig[0]
k = inp_orig[1]
v = inp_orig[2]
d_out = out_grad_orig
if backend == "FusedAttention":
q = inp[0]
k = inp[1]
v = inp[2]
d_out = out_grad
out = block(
q,
k,
v,
window_size=config.window_size,
attention_mask=attention_mask,
qkv_format=qkv_format,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=True,
)
if is_training:
out.backward(d_out)
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
return out, (q.grad, k.grad, v.grad)
else:
return out, (None, None, None)
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
if is_training:
q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
for i in range(1, config.batch_size + 1):
valid_range_q = (
cu_seqlens_q_after_pad[i - 1],
cu_seqlens_q_after_pad[i] - pad_len[i - 1],
)
valid_range_kv = (
cu_seqlens_kv_after_pad[i - 1],
cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
)
out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
if is_training:
q_grad_orig = torch.cat(
[q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
)
k_grad_orig = torch.cat(
[k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
v_grad_orig = torch.cat(
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
else:
return out_orig, (None, None, None)
else:
if is_training:
return out, (q.grad, k.grad, v.grad)
else:
return out, (None, None, None)
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
):
"""Test TransformerLayer module"""
# Get configs
config = model_configs[model]
tols = dict(atol=5e-2, rtol=5e-2)
workspace_opt = True
# Test backend availability
available_backends, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_transformer_layer(
dtype,
config,
"UnfusedDotProductAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
)
# FusedAttention backend
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_transformer_layer(
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
)
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_transformer_layer]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, unfused_attn_bwd, **tols)
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_transformer_layer]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_transformer_layer]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"])
def test_te_layer_misc(dtype, model_configs, model, qkv_format):
"""Test TransformerLayer module with miscellaneous settings"""
ckpt_attn = True
fused_qkv_params = True
RoPE = True
test_transformer_layer(
dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
"""Test TransformerLayer module with MQA/GQA"""
def find_factors(x):
f = []
for i in range(2, x + 1):
if x % i == 0:
f.append(i)
return f
ckpt_attn = True
qkv_format = "bshd"
fused_qkv_params = True
RoPE = True
config = model_configs[model]
    num_queries_per_gqa_group = find_factors(config.num_heads)
    for num_q_per_gqa_group in num_queries_per_gqa_group:
config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
test_transformer_layer(
dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
)
def _run_transformer_layer(
dtype: torch.dtype,
config: ModelConfig,
backend: str,
ckpt_attn: bool,
qkv_format: str,
workspace_opt: bool,
fused_qkv_params: bool,
RoPE: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass"""
# Set RNG and environment variables
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
# Create input tensor
inp = torch.randn(
config.max_seqlen_q,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
    # If the format to be tested is batch-first (bshd), transpose the
    # sequence-first input tensor accordingly.
if qkv_format == "bshd":
inp = inp.transpose(0, 1)
# Create seqlens
if "padding" in config.attn_mask_type:
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = attention_mask_q.to(device="cuda")
sigma = 0.02
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
layer_number = 1
drop_path_rate = 0.0
drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
# Create bias
bias = None
if config.attn_bias_type == "post_scale_bias":
bias = torch.randn(
1,
config.num_heads,
config.max_seqlen_q,
config.max_seqlen_kv,
dtype=dtype,
device="cuda",
)
# Create RoPE
rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
# Set up model
block = TransformerLayer(
config.hidden_size,
4 * config.hidden_size,
config.num_heads,
num_gqa_groups=config.num_gqa_groups,
layernorm_epsilon=1e-5,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim_qk,
self_attn_mask_type=config.attn_mask_type,
tp_group=None,
tp_size=1,
params_dtype=dtype,
get_rng_state_tracker=None,
fuse_wgrad_accumulation=False,
seq_length=config.max_seqlen_q,
micro_batch_size=config.batch_size,
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="encoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=fused_qkv_params,
zero_centered_gamma=False,
qkv_weight_interleaved=False,
ub_tp_comm_overlap=False,
bias=True,
attn_input_format=qkv_format,
).to(dtype=dtype, device="cuda")
# Create ALiBi slopes
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
# Run a forward and backward pass
out = block(
inp,
attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
)
loss = out.sum()
loss.backward()
return out, inp.grad
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
"fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
"fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"),
"fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"),
}
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
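# Root-mean-square error between two tensors (assumes a and b have the same number of elements).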
def _rmse(a, b):
return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
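    """Compare a and b with assert_close; if that fails, fall back to checking
    that the RMSE is small relative to the combined value range of a and b."""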
logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item()))
logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item()))
try:
if a.dtype != b.dtype:
a = a.to(b.dtype)
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
except Exception as e:
logging.debug(e)
rmse = _rmse(a, b)
logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
assert rmse < rmse_tol * rmse_range, (
name_a
+ " vs "
+ name_b
+ " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
)
)
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@pytest.mark.parametrize("input_layernorm", [True, False])
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
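    """Compare FP8 and F16 MultiheadAttention outputs (and gradients when training)."""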
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
9,
7,
0,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
)
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.15
logging.debug("========== {:^25s} ==========".format("forward output"))
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
_error(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
if is_training:
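        # Only the first entry, i.e. hidden_states.grad, is compared.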
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
_error(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
)
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
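    """Run one forward (and backward, if training) pass of MultiheadAttention;
    return the output, parameter names, and gradients."""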
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_mha,
fp8_mha=fp8_mha,
)
with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe):
rotary_pos_emb = None
if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim_qk)
rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
mha = MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
kv_channels=config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
layer_number=1,
bias=True,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
params_dtype=dtype,
input_layernorm=input_layernorm,
fuse_qkv_params=True,
attention_type="self",
qkv_weight_interleaved=True,
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
if not is_training:
mha = mha.eval()
if "padding" in config.attn_mask_type or qkv_format == "thd":
if config.attn_type == "self":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = seqlens_q
if config.attn_type == "cross":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
dim_to_num = {
"b": config.batch_size,
"sq": config.max_seqlen_q,
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
"2": 2,
"1": 1,
}
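    # Expand the qkv_format string into per-dimension tokens (e.g. "bshd" -> "b_sq_h_d")
    # so that each token can be mapped to a size via dim_to_num.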
layout = "_".join(qkv_format)
layout = layout.replace("s", "sq")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
hidden_states = tensor.view(*tensor.shape[:-2], -1)
if is_training:
hidden_states.requires_grad = True
tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
out_grad = tensor.view(*tensor.shape[:-2], -1)
with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
out = mha(
hidden_states,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=None,
rotary_pos_emb=rotary_pos_emb,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
if is_training:
out.backward(out_grad)
param_names = []
param_names.append("hidden_states.grad")
params = []
params.append(hidden_states)
for name, param in mha.named_parameters():
if param.requires_grad:
param_names.append(name + ".grad")
params.append(param)
if is_training:
return out, param_names, tuple(x.grad for x in params)
return out, param_names, tuple(None for x in params)
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
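    """Compare FP8 and F16 DotProductAttention outputs (and dq/dk/dv when training)."""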
config = model_configs_fp8_vs_f16[model]
# TODO(cyang): think of another way to verify dropout results
# test cuDNN FP8 dropout
# 1. we modify the config here to not affect mha_fp8_vs_f16 tests
# 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an
# indirect verification method, we create Q/K/V as all 1s and check if O is all 1s
# 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell
# if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type:
# if get_device_compute_capability() >= (10, 0):
# config.dropout_p = 0.1
if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < (
9,
7,
0,
):
pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7")
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
dtype, config, True, qkv_layout, is_training
)
if config.dropout_p == 0.0:
# test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
dtype, config, False, qkv_layout, is_training
)
atol = 5e-1
rtol = 5e-2
rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"]
logging.debug("========== {:^25s} ==========".format("forward output"))
if (
FlashAttentionUtils.v3_is_installed
and not is_training
and "padding" not in config.attn_mask_type
):
_error(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
if config.dropout_p != 0.0:
# test cuDNN FP8 dropout
assert torch.all(
fused_attn_fwd_fp8 == 1
), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
else:
_error(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"fused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
if is_training:
for i, _ in enumerate(fused_attn_bwd_f16):
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
_error(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
f"fused_attn_bwd_f16[{i}]",
atol,
rtol,
rmse_tol,
)
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
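    """Run one forward (and backward, if training) pass of DotProductAttention;
    return the output and input gradients."""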
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=fp8_dpa,
)
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
with fp8_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
config.num_heads,
config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self",
qkv_format=qkv_format,
).to(dtype=dtype, device="cuda")
if not is_training:
dpa = dpa.eval()
if "padding" in config.attn_mask_type or qkv_format == "thd":
if config.attn_type == "self":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = seqlens_q
if config.attn_type == "cross":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
dim_to_num = {
"b": config.batch_size,
"sq": config.max_seqlen_q,
"skv": config.max_seqlen_kv,
"h": config.num_heads,
"hg": config.num_gqa_groups,
"d": config.head_dim_qk,
"t": cu_seqlens_q[-1],
"tg": cu_seqlens_kv[-1],
"3": 3,
"2": 2,
"1": 1,
}
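    # Build the Q/K/V inputs fragment by fragment from qkv_layout; a packed
    # fragment (e.g. "sbh3d") is split into separate tensors below.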
inp = []
for i, layout in enumerate(qkv_layout.split("_")):
layout = "_".join(layout)
if i == 0:
layout = layout.replace("s", "sq")
else:
layout = layout.replace("s", "skv")
layout = layout.replace("h", "hg")
layout = layout.replace("t", "tg")
tensor_shape = [dim_to_num[j] for j in layout.split("_")]
if config.dropout_p == 0.0:
tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
else:
# test cuDNN FP8 dropout
tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda")
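        # A digit in the layout fragment (e.g. the "3" in "sbh3d") packs multiple
        # tensors into one; split along that dimension.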
tensor_count = 1
split_dim = 0
for dim, l in enumerate(layout.split("_")):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
break
tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
else:
inp.append(tensors[j])
for i in range(3):
inp[i].requires_grad = True
qkv_format_kv = "_".join(qkv_format)
qkv_format_kv = qkv_format_kv.replace("s", "sq")
out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
out = dpa(
inp[0],
inp[1],
inp[2],
qkv_format=qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
)
if is_training:
out.backward(out_grad)
if is_training:
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
return out, (None, None, None)
model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
"fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
"fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
@pytest.mark.skipif(
(
get_cudnn_version() < (8, 9, 3)
if cudnn_frontend_version == 0
else get_cudnn_version() < (9, 2, 1)
),
reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
"""Test FP8 dot product attention implementations based on cuDNN frontend
v0.9 and v1.0+. Each test compares results from a custom implementation of
an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
    implementation, i.e. transformer_engine.pytorch.attention.MultiheadAttention.
    Both paths take F16 input and output. QKV layout is t3hd or bs3hd."""
config = model_configs_fp8[model]
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.13
_error(
fused_attn_fwd_fp8,
unfused_attn_fwd_f16,
"fused_attn_fwd_fp8",
"unfused_attn_fwd_f16",
atol,
rtol,
rmse_tol,
)
_error(
fused_attn_bwd_fp8,
unfused_attn_bwd_f16,
"fused_attn_bwd_fp8",
"unfused_attn_bwd_f16",
atol,
rtol,
rmse_tol,
)
def _run_custom_mha_fp8(dtype, config, backend):
"""Run Custom_MHA_FP8 with FP8 FusedAttention backend. Both input and output
are in F16. QKV GEMM, DPA, and projection GEMM are calculated in FP8."""
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
inp = 0.0001 * torch.randint(
-100,
100,
(config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim_qk),
dtype=dtype,
device="cuda",
requires_grad=True,
)
seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
out_grad = 0.01 * torch.randn(
config.batch_size * config.max_seqlen_q,
config.num_heads * config.head_dim_qk,
dtype=dtype,
device="cuda",
)
torch.save(out_grad, "out_grad.pt")
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
)
mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
out = mha(inp, cu_seqlens, config.max_seqlen_q)
out.backward(out_grad)
out = torch.load("out.pt")
dqkv = torch.load("dqkv.pt")
return (
out.view(config.batch_size, config.max_seqlen_q, -1),
dqkv.view(
config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk
).contiguous(),
)
def _run_ref_mha_f16(dtype, config, backend):
"""Run reference F16 FusedAttention. Both input and output
are in F16. QKV GEMM, DPA, and projection GEMM are also in F16."""
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
inp = torch.load("qkv.pt").to(device="cuda")
inp.requires_grad = True
seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
out_grad = (
torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
)
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
block = DotProductAttention(
config.num_heads,
config.head_dim_qk,
attention_dropout=config.dropout_p,
sequence_parallel=False,
tp_size=1,
get_rng_state_tracker=get_dummy_cuda_rng_tracker,
tp_group=None,
layer_number=1,
attention_type="self",
qkv_format="bshd",
).to(dtype=dtype, device="cuda")
q = inp[:, :, 0, :, :]
k = inp[:, :, 1, :, :]
v = inp[:, :, 2, :, :]
out = block(q, k, v, attn_mask_type=config.attn_mask_type)
out.backward(out_grad)
return out, inp.grad
_CUBLASLT_WORKSPACE_SIZE_BYTES = 33_554_432 # 32MiB
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
META_O = tex.FP8FwdTensors.GEMM2_INPUT
META_DO = tex.FP8BwdTensors.GRAD_INPUT2
META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
META_DP = tex.FP8BwdTensors.GRAD_INPUT3
class _custom_mha_fp8(torch.autograd.Function):
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
qkv_weight: torch.Tensor,
qkv_bias: torch.Tensor,
cu_seqlens: torch.Tensor,
num_heads: int,
p_dropout: float,
max_s: int,
fast_zero_fill: bool,
fp8_meta: Dict[str, Any],
workspace: torch.Tensor,
is_training: bool,
mask_type: str,
quantizers: list[Quantizer],
) -> torch.Tensor:
qkv_dtype = inp.dtype
assert inp.dim() == 2
in_features = qkv_weight.shape[-1]
h = num_heads
d = in_features // h
b = cu_seqlens.numel() - 1
input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]
inp_fp8 = input_quantizer(inp)
qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight)
qkv, *_ = ext.general_gemm(
qkv_weight_fp8,
inp_fp8,
workspace,
bias=qkv_bias,
out_dtype=qkv_weight_fp8.dtype,
quantization_params=qkv_quantizer,
use_split_accumulator=_2X_ACC_FPROP,
)
qkv = qkv.view(-1, 3, h, d)
qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
torch.save(qkv_fp16, "qkv.pt")
if cudnn_frontend_version == 1:
qkv = qkv.view(b, max_s, 3, h, d) # bs3hd
# FMHA
q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)
out, aux_ctx_tensors = fused_attn_fwd(
is_training,
max_s,
max_s,
cu_seqlens,
cu_seqlens,
q,
k,
v,
qkv_dtype,
FusedAttnBackend["FP8"],
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
attn_bias_type="no_bias",
attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
rng_gen=None,
o_quantizer=o_quantizer,
s_quantizer=s_quantizer,
)
tensors_to_save, tensor_objects = prepare_for_saving(
q, k, v, inp_fp8, qkv_weight_fp8, workspace, out
)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.qkv_dtype = qkv_dtype
ctx.fp8_meta = fp8_meta
ctx.cu_seqlens = cu_seqlens
ctx.p_dropout = p_dropout
ctx.max_s = max_s
ctx.fast_zero_fill = fast_zero_fill
ctx.hidden_size = in_features
ctx.num_heads = num_heads
ctx.mask_type = mask_type
ctx.dtype = inp.dtype
ctx.dQKV_quantizer = dQKV_quantizer
ctx.dO_quantizer = dO_quantizer
ctx.dP_quantizer = dP_quantizer
ctx.S_quantizer = s_quantizer
out = out.view(-1, in_features) # (bs)(hd)
out_fp16 = out.dequantize()
torch.save(out_fp16, "out.pt") # (bs)(hd)
return out_fp16
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"):
saved_tensors = ctx.saved_tensors
(q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved(
ctx.tensor_objects, saved_tensors
)
proj_dgrad = ctx.dO_quantizer(grad_output)
fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
dq, dk, dv, *rest = fused_attn_bwd(
ctx.max_s,
ctx.max_s,
ctx.cu_seqlens,
ctx.cu_seqlens,
q,
k,
v,
out,
proj_dgrad.view_as(out),
ctx.qkv_dtype,
fp8_dtype_backward,
ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"],
None,
None,
ctx.S_quantizer,
ctx.dP_quantizer,
ctx.dQKV_quantizer,
attn_scale=None,
dropout=ctx.p_dropout,
fast_zero_fill=ctx.fast_zero_fill,
qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
attn_bias_type="no_bias",
attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
)
dim = 2 if cudnn_frontend_version == 1 else 1
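            # dq, dk, and dv share one interleaved buffer; rebuild the packed dqkv
            # view by inserting the factor-3 dimension into dq's shape and stride.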
dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
dqkv_shape = list(dq._data.shape)
dqkv_shape.insert(dim, 3)
dqkv_stride = list(dq._data.stride())
dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
dqkv.set_(
dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride
) # bs3hd
dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape)
dqkv_c_fp16 = dqkv_c.dequantize()
torch.save(dqkv_c_fp16, "dqkv.pt")
qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer)
dqkv_c._transpose = None
dqkv_c._create_transpose()
# QKV DGRAD
qkv_dgrad, *_ = ext.general_gemm(
qkv_weight_fp8,
dqkv_c,
workspace,
ctx.dtype,
use_split_accumulator=_2X_ACC_DGRAD,
layout="NN",
)
# QKV WGRAD
qkv_wgrad, *_ = ext.general_gemm(
inp_fp8,
dqkv,
workspace,
ctx.dtype,
use_split_accumulator=_2X_ACC_WGRAD,
layout="NT",
)
return (
qkv_dgrad,
qkv_wgrad,
qkv_bgrad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
class Custom_MHA_FP8(TransformerEngineBaseModule):
def __init__(self, config, params_dtype: torch.dtype = torch.float32):
super().__init__()
self.p_dropout = config.dropout_p
self.h = config.num_heads
self.hidden_size = config.hidden_size
self.head_dim = config.head_dim_qk
self.fast_zero_fill = True
self.mask_type = config.attn_mask_type
self.qkv_weight = torch.nn.Parameter(
torch.empty(
self.hidden_size * 3,
self.hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
self.qkv_bias = torch.nn.Parameter(
torch.empty(
self.hidden_size * 3,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
with torch.no_grad():
self.qkv_bias.zero_()
self.qkv_weight.fill_(1.0)
self.workspace = torch.empty(
_CUBLASLT_WORKSPACE_SIZE_BYTES, dtype=torch.int8, device="cuda"
)
def forward(
self,
inp: torch.Tensor,
cu_seqlens,
max_s,
) -> torch.Tensor:
with self.prepare_forward(inp, num_gemms=3) as inp:
out = _custom_mha_fp8.apply(
inp,
self.qkv_weight,
self.qkv_bias,
cu_seqlens,
self.h,
self.p_dropout,
max_s,
self.fast_zero_fill,
self.fp8_meta,
self.workspace,
self.training,
self.mask_type,
self.quantizers,
)
return out
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
import pytest
import torch
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
get_cudnn_version,
)
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_1_3": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
"cp_2_3": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512)
), # GQA
}
def get_bash_arguments(num_gpus_per_node, **kwargs):
args = [
"python3",
"-m",
"torch.distributed.launch",
"--nproc-per-node=" + str(num_gpus_per_node),
]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
args.append(script_path)
for k, v in kwargs.items():
args.append(f"{k}={v}")
return args
@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
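    """Test context parallelism with the FlashAttention backend (launched via torch.distributed)."""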
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
config = model_configs_flash_attn[model]
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if "a2a" in cp_comm_type and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
subprocess.run(
get_bash_arguments(
num_gpus_per_node=num_gpus,
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
),
check=True,
)
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
"cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
"cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
"cp_1_4": ModelConfig(
2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # MHA
"cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
"cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
"cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
"cp_2_4": ModelConfig(
2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0)
), # GQA
}
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
@pytest.mark.parametrize("fp8_mha", [False, True])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
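    """Test context parallelism with the FusedAttention backend (launched via torch.distributed)."""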
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
if dtype == "fp8" and get_device_compute_capability() < (9, 0):
pytest.skip("FP8 attention is only supported on sm90+!")
config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and "a2a" in cp_comm_type:
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
)
if dtype == "fp8" and qkv_format == "thd":
pytest.skip("FP8 attention cannot work with THD format yet!")
if dtype == "fp8" and config.attn_bias_type != "no_bias":
pytest.skip("FP8 attention cannot work with bias yet!")
if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("FP8 attention cannot work with sliding window yet!")
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
if dtype != "fp8" and fp8_mha:
pytest.skip("Only fp8 works with fp8_mha=True!")
subprocess.run(
get_bash_arguments(
num_gpus_per_node=num_gpus,
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
fp8_mha=fp8_mha,
),
check=True,
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType_To_Torch
# compute amax and scale
def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
x_fp32 = x.to(torch.float32)
amax = torch.amax(torch.abs(x_fp32)).view(1)
assert amax.dtype == torch.float, "amax must be a float tensor."
fp8_max = torch.finfo(quant_dtype).max
# Clamping amax to avoid division by small numbers
amax = torch.max(amax, torch.tensor(eps))
# Compute scale factor
scale = torch.div(fp8_max, amax)
    # Note: frexp does not return an inf exponent for an inf input, so
    # handle inf scales before the pow-2 path below.
    # option 1: when scale is inf, set scale to the FP32 max
    scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale)
    # option 2: when scale is inf, set scale to 1 (a no-op here, since
    # option 1 above has already replaced any inf scales)
    scale = torch.where(scale == torch.inf, 1.0, scale)
if pow_2_scales:
# Calculate rounded down exponent
_, exp = torch.frexp(scale)
# Positive numbers are always returned as mant, exp with
# a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
# hidden bit in [1.0, 2.0), the exponent will be off by exactly one because
# of the shift. Subnormal and zero cases need not be considered because
# the smallest possible result of fp8_max / amax is still normal.
exp = exp - 1
# No subnormals and zero.
assert (exp > -127).all()
        # TODO: if/when adding a URM option, consider capping exp at 126 rather
        # than allowing the full FP32 range ((2 - 2^-23) x 2^127); this addresses
        # cases where adding a mantissa overflows into inf scales. Not necessary
        # currently without additional scale-smudging options.
unity = torch.tensor([1.0], device=exp.device)
torch.ldexp(unity, exp, out=scale)
        # Case where amax is inf: the frexp/ldexp logic changes 0.0 scales,
        # so return 0.0 for a 0.0 scale, for consistency with the non-pow-2
        # scale calculation.
scale = torch.where(amax == float("inf"), 0.0, scale)
        # Handle overflow cases where an amax of zero would cause NaN
scale = torch.where(amax == 0, 1.0, scale)
# Compute scale_inv
scale_inv = torch.reciprocal(scale)
return scale, scale_inv, amax
def _multi_dim_transpose(tensor):
# Get the number of dimensions
dims = list(range(len(tensor.shape)))
if len(dims) <= 1:
return tensor
# circular shift of shapes
new_order = []
new_order.append(dims[-1])
for i in range(len(dims) - 1):
new_order.append(dims[i])
# Permute the tensor according to the new order
output_tensor = tensor.permute(new_order).contiguous()
return output_tensor
# current scaling reference quantization
def ref_per_tensor_cs_cast(
tensor: torch.Tensor,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
return_transpose: bool = False,
force_pow_2_scales: bool = False,
amax_epsilon: float = 0.0,
) -> tuple:
quant_dtype_torch = TE_DType_To_Torch[fp8_dtype]
scale, scale_inv, _ = _ref_compute_amax_scale(
tensor,
quant_dtype_torch,
amax_epsilon,
force_pow_2_scales,
)
qx = (tensor.float() * scale).to(quant_dtype_torch)
sx = scale_inv
qx_t = None
sx_t = None
if tensor.shape == torch.Size([]):
qx = qx.view([])
if return_transpose:
qx_t = _multi_dim_transpose(qx)
sx_t = sx
return qx, sx, qx_t, sx_t
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
from contextlib import nullcontext
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
# Check if FP8 supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
SIZE = 512
models = {
"linear": te.Linear,
"layernorm_mlp": te.LayerNormMLP,
"layernorm_linear": te.LayerNormLinear,
}
def _get_input():
return torch.empty((128, SIZE, SIZE)).cuda()
def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload):
input_layer = model_cls(SIZE, SIZE)
hidden_layer = model_cls(SIZE, SIZE)
output_layer = model_cls(SIZE, SIZE)
input = _get_input()
if cpu_offload:
offload_context, sync_function = te.get_cpu_offload_context(
enabled=True,
num_layers=2,
model_layers=3,
offload_activations=True,
offload_weights=False,
)
else:
offload_context = nullcontext()
sync_function = lambda x: x
with te.fp8_autocast(enabled=fp8), offload_context:
out = input_layer(input)
out = sync_function(out)
with te.fp8_autocast(enabled=fp8), offload_context:
out = hidden_layer(out)
out = sync_function(out)
with te.fp8_autocast(enabled=fp8), offload_context:
out = output_layer(out)
out = sync_function(out)
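    # Measure GPU memory here, between the last forward and the backward pass,
    # where offloaded activations should reside on the CPU.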
max_mem_used = torch.cuda.memory_allocated() / 1024**2
out.sum().backward()
del input_layer
del hidden_layer
del output_layer
del input
del out
torch.cuda.synchronize()
return max_mem_used
@pytest.mark.parametrize("fp8", [True, False])
@pytest.mark.parametrize("model_key", models.keys())
def test_cpu_offload(fp8, model_key) -> None:
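    """Check that CPU offloading reduces GPU memory held between forward and backward."""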
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
model_cls = models[model_key]
without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False)
with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True)
assert with_offloading < without_offloading
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from dataclasses import dataclass
import itertools
from typing import Iterable, List, Tuple, Union
import pytest
import torch
from transformer_engine.pytorch import (
DotProductAttention,
LayerNormLinear,
LayerNormMLP,
Linear,
MultiheadAttention,
TransformerLayer,
fp8_autocast,
fp8_model_init,
make_graphed_callables,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
# Record initial RNG state.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
@dataclass
class ModelConfig:
"""Data tensor dimensions within Transformer model"""
sequence_length: int
batch_size: int
hidden_size: int
num_heads: int
kv_channels: int
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
fp8_recipes = [
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
]
# Supported data types
dtypes: List[torch.dtype] = [torch.float32, torch.float16]
if is_bf16_compatible(): # bf16 requires sm_80 or higher
dtypes.append(torch.bfloat16)
def reset_rng_states() -> None:
"""Revert to initial RNG state."""
torch.set_rng_state(_cpu_rng_state)
torch.cuda.set_rng_state(_cuda_rng_state)
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
FP8GlobalStateManager.reset()
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> None:
"""Check that two lists of tensors match exactly."""
assert len(l1) == len(l2), "Unequal number of outputs."
failure_message = "Output mismatches in:"
failed_tensors = []
for i, (t1, t2) in enumerate(zip(l1, l2)):
if not torch.equal(t1, t2):
failure_message += "\n "
if names is None:
failure_message += f"tensor at idx={i}"
else:
failure_message += names[i]
failed_tensors.append((t1, t2))
if failed_tensors:
print(failure_message)
t1, t2 = failed_tensors[0]
torch.testing.assert_close(t1, t2, rtol=0, atol=0)
def generate_data(
model_config: ModelConfig,
dtype: torch.dtype,
warmup: bool = False,
requires_grad: bool = True,
) -> torch.Tensor:
"""Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn
return gen_func(
model_config.sequence_length,
model_config.batch_size,
model_config.hidden_size,
device="cuda",
requires_grad=requires_grad,
dtype=dtype,
)
def get_outputs(
model: torch.nn.Module,
output: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> List[torch.Tensor]:
"""Return grads and params for comparsion."""
values = []
for param in model.parameters():
values.append(param)
if param.grad is not None:
values.append(param.grad)
if isinstance(output, torch.Tensor):
values.append(output)
else:
values.extend(output)
return values
class _Sequential(torch.nn.Sequential):
"""Sequential model that forwards keyword arguments to modules"""
def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
x = input_
for module in self:
x = module(x, **kwargs)
return x
# Supported modules
_test_cuda_graphs_modules: List[str] = [
"transformer",
"layernorm_mlp",
"layernorm_linear",
"linear",
"mha",
"linear_op",
]
def _test_cuda_graphs(
*,
graph_mode: str,
module: str,
model_config: ModelConfig,
num_layers: int,
dtype: torch.dtype,
fp8: bool,
fp8_params: bool,
fp8_weight_caching: bool,
fp8_recipe: recipe.Recipe,
) -> List[torch.Tensor]:
"""Helper function for CUDA graph test."""
reset_rng_states()
FP8GlobalStateManager.reset()
# Operation-based API does not support FP8 weight caching.
if module == "linear_op":
fp8_weight_caching = False
# Create modules.
with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe):
if module == "transformer":
modules = [
TransformerLayer(
model_config.hidden_size,
model_config.hidden_size,
model_config.num_heads,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
)
for _ in range(num_layers)
]
elif module == "layernorm_mlp":
modules = [
LayerNormMLP(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
)
for _ in range(num_layers)
]
elif module == "layernorm_linear":
modules = [
LayerNormLinear(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
)
for _ in range(num_layers)
]
elif module == "mha":
modules = [
MultiheadAttention(
model_config.hidden_size,
model_config.num_heads,
attention_dropout=0.0,
params_dtype=dtype,
fuse_qkv_params=True,
)
for _ in range(num_layers)
]
elif module == "linear":
modules = [
Linear(
model_config.hidden_size,
model_config.hidden_size,
device="cuda",
params_dtype=dtype,
)
for _ in range(num_layers)
]
elif module == "linear_op":
modules = [
te_ops.Sequential(
te_ops.Linear(
model_config.hidden_size,
model_config.hidden_size,
dtype=dtype,
),
)
for _ in range(num_layers)
]
else:
raise ValueError(f"Unknown module type ({module})")
# Initialize gradient buffers.
for module in modules:
for param in module.parameters():
param.grad = torch.empty_like(param)
# Generate model and wrap API to return graphed version.
if graph_mode == "full":
# Graph entire model at once.
model = torch.nn.Sequential(*modules)
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe,
)
elif graph_mode == "individual":
# Graph individual modules.
modules = [
make_graphed_callables(
module,
(generate_data(model_config, dtype, warmup=True),),
num_warmup_iters=10,
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe,
)
for module in modules
]
model = _Sequential(*modules)
else:
model = _Sequential(*modules)
# Optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Training steps.
for _ in range(3):
optimizer.zero_grad(set_to_none=False)
for grad_accumulation_step in range(2):
input_ = generate_data(model_config, dtype)
grad_output = generate_data(model_config, dtype, requires_grad=False)
with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
kwargs = {}
if fp8_weight_caching:
kwargs["is_first_microbatch"] = grad_accumulation_step == 0
output = model(input_, **kwargs)
output.backward(grad_output)
optimizer.step()
return get_outputs(model, output)
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8", (False, True))
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables(
*,
module: str,
model_config: str = "small",
num_layers: int = 3,
dtype: torch.dtype,
fp8: bool,
fp8_params: bool,
fp8_recipe: recipe.Recipe,
fp8_weight_caching: bool = False,
) -> None:
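    """Compare results from an ungraphed model, a fully graphed model, and per-module graphed modules."""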
# Skip invalid configurations.
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_params and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
kwargs = dict(
module=module,
model_config=model_config,
num_layers=num_layers,
dtype=dtype,
fp8=fp8,
fp8_params=fp8_params,
fp8_weight_caching=fp8_weight_caching,
fp8_recipe=fp8_recipe,
)
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
# Check that results match.
assert_all_equal(outputs, graph_outputs_mode1)
assert_all_equal(outputs, graph_outputs_mode2)
_test_make_graphed_callables_with_fp8_weight_caching_modules = [
"transformer",
"layernorm_mlp",
"layernorm_linear",
"linear",
"mha",
]
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize(
"module",
_test_make_graphed_callables_with_fp8_weight_caching_modules,
)
@pytest.mark.parametrize("fp8_params", (False, True))
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_make_graphed_callables_with_fp8_weight_caching(
*,
module: str,
fp8_params: bool,
fp8_recipe: recipe.Recipe,
) -> None:
test_make_graphed_callables(
module=module,
dtype=torch.float32,
fp8=True,
fp8_params=fp8_params,
fp8_recipe=fp8_recipe,
fp8_weight_caching=True,
)
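# Hedged sketch of the is_first_microbatch contract exercised above: when a
# model is graphed with fp8_weight_caching=True, only the first microbatch of
# a step re-quantizes weights to FP8 and later microbatches reuse the cached
# casts. The helper below is illustrative only and never called; its
# arguments are assumptions, not part of the test suite.
def _example_fp8_weight_caching_step(graphed_model, optimizer, microbatches, fp8_recipe):
    optimizer.zero_grad(set_to_none=False)
    for step, (x, dy) in enumerate(microbatches):
        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            # Reuse cached FP8 weight casts on every microbatch but the first.
            y = graphed_model(x, is_first_microbatch=(step == 0))
        y.backward(dy)
    optimizer.step()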
def generate_data_for_dot_product_attention(
model_config: ModelConfig,
dtype: torch.dtype,
warmup: bool = False,
) -> List[torch.Tensor]:
"""Generate synthetic data for dot product attention."""
gen_func = torch.ones if warmup else torch.randn
return [
gen_func(
model_config.sequence_length,
model_config.batch_size,
model_config.num_heads,
model_config.kv_channels,
device="cuda",
requires_grad=True,
dtype=dtype,
)
for _ in range(3)
]
def _test_cuda_graphs_with_dot_product_attention(
*,
with_graph: bool,
model_config: ModelConfig,
dtype: torch.dtype,
) -> List[torch.Tensor]:
"""Helper function for CUDA graph test."""
reset_rng_states()
FP8GlobalStateManager.reset()
# Create dot product attention module.
assert model_config.hidden_size % model_config.num_heads == 0
model = DotProductAttention(
model_config.num_heads,
model_config.kv_channels,
attention_dropout=0.0,
)
# Graph model if needed.
if with_graph:
model = make_graphed_callables(
model,
generate_data_for_dot_product_attention(model_config, dtype, warmup=True),
num_warmup_iters=10,
fp8_enabled=False,
)
# Forward and backward passes.
for _ in range(3):
inputs = generate_data_for_dot_product_attention(model_config, dtype)
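        # The attention output is (seq, batch, hidden_size), so the generic
        # generate_data tensor already has the right shape for grad_output.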
grad_output = generate_data(model_config, dtype, requires_grad=False)
output = model(*inputs)
output.backward(grad_output)
return get_outputs(model, output)
@pytest.mark.parametrize("dtype", dtypes)
def test_make_graphed_callables_with_dot_product_attention(
*,
model_config: str = "small",
dtype: torch.dtype,
) -> None:
"""Test CUDA graphs with dot product attention."""
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs)
graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs)
assert_all_equal(outputs, graph_outputs)
def _test_cuda_graphs_with_kwargs(
*,
with_graph: bool,
model_config: ModelConfig,
dtype: torch.dtype,
) -> List[torch.Tensor]:
"""Helper function for CUDA graph test with keyword arguments."""
reset_rng_states()
# Initialize model.
model = TransformerLayer(
model_config.hidden_size,
model_config.hidden_size,
model_config.num_heads,
hidden_dropout=0.0,
attention_dropout=0.0,
self_attn_mask_type="arbitrary",
fuse_qkv_params=True,
params_dtype=dtype,
)
# Initialize gradient buffers.
for param in model.parameters():
param.grad = torch.empty_like(param)
# Make graphed version of model if needed.
if with_graph:
attn_mask = torch.zeros(
(
model_config.batch_size,
1,
model_config.sequence_length,
model_config.sequence_length,
),
dtype=torch.bool,
device="cuda",
)
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True),),
sample_kwargs=dict(attention_mask=attn_mask),
allow_unused_input=True,
)
# Optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Training loop.
for _ in range(3):
optimizer.zero_grad(set_to_none=False)
for grad_accumulation_step in range(2):
input_ = generate_data(model_config, dtype)
grad_output = generate_data(model_config, dtype, requires_grad=False)
attn_mask = torch.randint(
2,
(
model_config.batch_size,
1,
model_config.sequence_length,
model_config.sequence_length,
),
dtype=torch.bool,
device="cuda",
)
output = model(input_, attention_mask=attn_mask)
output.backward(grad_output)
optimizer.step()
return get_outputs(model, output)
def test_make_graphed_callables_with_kwargs(
*,
model_config: str = "small",
dtype: torch.dtype = torch.float32,
) -> None:
"""Test CUDA graphs with keyword arguments."""
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs)
graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs)
assert_all_equal(outputs, graph_outputs)
def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
*,
with_graph: bool,
model_config: ModelConfig,
dtype: torch.dtype,
) -> List[torch.Tensor]:
"""Simulate Megatron-LM interleaved pipeline parallelism."""
reset_rng_states()
# Pipeline parallel configuration.
num_layers = 2
num_microbatches = 3
layer_order = [1, 2, 1, 2, -2, -1, 1, 2, -2, -1, -2, -1]
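    # In the _order schedule consumed by make_graphed_callables, a positive
    # entry k means "run the forward of the k-th callable" and a negative
    # entry -k means the matching backward, so this list encodes an
    # interleaved 1F1B-style pipeline over 2 layers x 3 microbatches.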
# Initialize model.
model = torch.nn.ModuleList(
[
Linear(
model_config.hidden_size,
model_config.hidden_size,
params_dtype=dtype,
)
for _ in range(num_layers)
]
)
# Initialize gradient buffers.
for param in model.parameters():
param.grad = torch.empty_like(param)
# Make graphed version of model if needed.
layer_forwards = {
(i % num_layers, i // num_layers): model[i % num_layers]
for i in range(num_layers * num_microbatches)
}
if with_graph:
sample_args = tuple(
(generate_data(model_config, dtype, warmup=True),)
for _ in range(num_layers * num_microbatches)
)
layer_forwards = make_graphed_callables(
tuple(model),
sample_args,
allow_unused_input=True,
_order=layer_order,
)
layer_forwards = {
(i // num_microbatches, i % num_microbatches): forward
for i, forward in enumerate(layer_forwards)
}
# Optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Training loop.
for _ in range(3):
optimizer.zero_grad(set_to_none=False)
# Generate data.
inputs = {}
grad_outputs = {}
for layer_idx in range(num_layers):
for microbatch_idx in range(num_microbatches):
x = generate_data(model_config, dtype)
dy = generate_data(model_config, dtype, requires_grad=False)
idxs = (layer_idx, microbatch_idx)
inputs[idxs] = x
grad_outputs[idxs] = dy
# Cache for layer outputs.
outputs = {}
def forward(layer_idx: int, microbatch_idx: int):
"""Helper function for forward steps"""
idxs = (layer_idx, microbatch_idx)
outputs[idxs] = layer_forwards[idxs](inputs[idxs])
def backward(layer_idx: int, microbatch_idx: int):
"""Helper function for backward steps"""
outputs[layer_idx, microbatch_idx].backward(grad_outputs[layer_idx, microbatch_idx])
# Forward and backward steps.
forward(0, 0)
forward(1, 0)
forward(0, 1)
forward(1, 1)
backward(1, 0)
backward(0, 0)
forward(0, 2)
forward(1, 2)
backward(1, 1)
backward(0, 1)
backward(1, 2)
backward(0, 2)
# Optimizer step.
optimizer.step()
outputs = [y for _, y in sorted(outputs.items())]
return get_outputs(model, outputs)
def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
*,
model_config: str = "small",
dtype: torch.dtype = torch.float16,
) -> None:
"""Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=False,
**kwargs,
)
graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=True,
**kwargs,
)
assert_all_equal(outputs, graph_outputs)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
_core_modules = [
te.LayerNorm,
te.RMSNorm,
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
]
_composed_modules = [
te.MultiheadAttention,
te.TransformerLayer,
]
batch_size = 32
seq_length = 2048
num_heads = 16
head_dim = 64
dtype = torch.bfloat16
class TestDeferredInit:
@staticmethod
def get_module_args(module):
hidden_size = num_heads * head_dim
args = (hidden_size,)
kwargs = {"params_dtype": dtype, "device": "meta"}
if module in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 2 * hidden_size
args += (ffn_hidden_size,)
kwargs["bias"] = True
if module == te.LayerNormMLP:
kwargs["seq_length"] = seq_length
elif module == te.MultiheadAttention:
args += (num_heads,)
kwargs["fuse_qkv_params"] = True
elif module == te.TransformerLayer:
args += (3 * hidden_size, num_heads)
kwargs["fuse_qkv_params"] = True
kwargs["seq_length"] = seq_length
return args, kwargs
@pytest.mark.parametrize("module_type", _core_modules + _composed_modules)
def test_zero_memory_init(
self,
module_type: torch.nn.Module,
) -> None:
"""Test deferred initialization via device='meta'."""
# This should not allocate any memory on CUDA device until we call reset_parameters() later.
args, kwargs = TestDeferredInit.get_module_args(module_type)
module = module_type(*args, **kwargs)
assert torch.cuda.memory_allocated(device=0) == 0.0, (
f"Initializing {module_type.__name__} with device='meta' prematurely allocated "
"memory on CUDA device"
)
del module
@pytest.mark.parametrize("module_type", _core_modules)
def test_reset_parameters(
self,
module_type: torch.nn.Module,
) -> None:
"""Test parameter reset for core modules that have been initialized with device='meta'."""
# Core modules own their own parameters so calling reset_parameters() here should
# materialize them on CUDA device.
args, kwargs = TestDeferredInit.get_module_args(module_type)
module = module_type(*args, **kwargs)
with torch.no_grad():
module.reset_parameters()
assert torch.cuda.memory_allocated(device=0) > 0.0, (
f"{module_type.__name__}.reset_parameters() failed to materialize parameters "
"on CUDA device"
)
del module
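# Minimal sketch of the deferred-initialization pattern the tests above
# check, assuming a CUDA device with no prior allocations: construct on the
# meta device (no allocation), then materialize parameters with
# reset_parameters(). Illustrative only and never invoked here.
def _example_deferred_init():
    layer = te.Linear(1024, 1024, params_dtype=dtype, device="meta")
    assert torch.cuda.memory_allocated(device=0) == 0  # nothing on CUDA yet
    with torch.no_grad():
        layer.reset_parameters()  # parameters materialize on CUDA here
    assert torch.cuda.memory_allocated(device=0) > 0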
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib
import os
import torch
import pytest
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8CurrentScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
# read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
tensor_dump_dir_env = os.getenv("NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
class GetRecipes:
@staticmethod
def none():
return None
@staticmethod
def fp8_per_tensor_current_scaling_default():
# return default configs
return Float8CurrentScaling()
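# Hedged sketch of how a recipe from GetRecipes is consumed: fp8_autocast
# takes the recipe object and applies FP8 per-tensor current scaling inside
# the context. Sizes and dtype below are illustrative assumptions; the
# function is never invoked by the tests.
def _example_current_scaling_forward():
    layer = te.Linear(128, 128, params_dtype=torch.bfloat16, device="cuda")
    x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")
    with fp8_autocast(enabled=True, fp8_recipe=GetRecipes.fp8_per_tensor_current_scaling_default()):
        y = layer(x)
    return y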
# base class for validating current_scaling x linear layer
class TestFP8RecipeLinearBase:
@staticmethod
def _prepare_data(
batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32
):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda")
bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None
gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda")
return x, w, bias, gradient
@staticmethod
def _shard_tensor(x, world_size, axis):
split_size = x.size()[axis] // world_size
split_tensor = torch.split(x, split_size, axis)
out = []
for tensor in split_tensor:
out.append(tensor.detach().clone().requires_grad_(x.requires_grad))
return out
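    # Illustrative sketch (not a test): sharding a [4, 8] tensor across
    # world_size=2 on axis 1 yields two detached [4, 4] shards that
    # concatenate back to the original.
    @staticmethod
    def _example_shard_round_trip():
        full = torch.arange(32, dtype=torch.float32).reshape(4, 8)
        shards = TestFP8RecipeLinearBase._shard_tensor(full, 2, 1)
        assert torch.equal(torch.cat(shards, dim=1), full)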
@staticmethod
def _gather_tensor(local, world_size, tp_group, concat_dim):
out_list = [torch.zeros_like(local) for _ in range(world_size)]
torch.distributed.all_gather(out_list, local, tp_group)
return torch.cat(out_list, dim=concat_dim)
@staticmethod
def _all_reduce_tensor(local, world_size, tp_group):
if world_size == 1:
return local
        # Synchronous all-reduce mutates `local` in place; no handle is returned.
        torch.distributed.all_reduce(local, group=tp_group, async_op=False)
        return local
@staticmethod
def _get_sum_abs_error(a, b):
return torch.sum(torch.abs(a - b))
@staticmethod
def _get_mean_abs_relative_error(a, b):
return torch.mean(torch.abs((a - b) / b))
@staticmethod
def _load_golden_tensor_values(a, b):
return torch.sum(torch.abs(a - b))
@staticmethod
def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias):
recipe = get_recipe()
batch_size, hidden_size, out_size = dims
fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True)
fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True)
fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)
# Expected tensor names based on the naming template
        # The dump naming template assumes per-tensor scaling.
        scaling_type = "ScalingType.PER_TENSOR"
current_seed = torch.initial_seed() # Get the current seed
expected_tensor_names = {
"y": f"y_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"dgrad": f"dgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"wgrad": f"wgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"bgrad": f"bgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
}
if not use_bias:
expected_tensor_names.pop("bgrad")
# Check if all expected tensors are in the tensor dumps directory
tensor_map = {}
for tensor_key, tensor_name in expected_tensor_names.items():
tensor_path = dump_dir / tensor_name
if not os.path.exists(tensor_path):
print(f"Missing tensor: {tensor_name}")
return None
# Load the tensor
tensor_map[tensor_key] = torch.load(tensor_path)
return tensor_map
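    # Example of the naming template above (all values illustrative): with
    # batch/hidden/out = 16/256/128, seed 1234, bf16 input, and current
    # scaling, the output dump would be named roughly
    # "y_ScalingType.PER_TENSOR_16_256_128_1234_torch.bfloat16_torch.float8_e4m3fn_torch.float8_e4m3fn_torch.float8_e5m2.pt".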
@classmethod
def run_linear_preprocess_parallel(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_size=1,
rank=0,
):
if tp_size > 1:
if parallel_mode == "column":
# split w in N dim, which should be axis 0
w = cls._shard_tensor(w, tp_size, 0)[rank]
bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
# split gradient in N dim, which should be axis 1
gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
if sequence_parallel:
# split x in M dim, which should be axis 0
x = cls._shard_tensor(x, tp_size, 0)[rank]
# row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1
if parallel_mode == "row":
# split x in K dim, which should be axis 1
x = cls._shard_tensor(x, tp_size, 1)[rank]
# split w in K dim, which should be axis 1
w = cls._shard_tensor(w, tp_size, 1)[rank]
if sequence_parallel:
# split gradient in M dim, which should be axis 0
gradient = cls._shard_tensor(gradient, tp_size, 0)[rank]
return x, w, bias, gradient
@classmethod
def run_linear_postprocess_parallel(
cls,
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
):
if tp_size > 1:
if parallel_mode == "column":
# gather y_q in N dim, which should be axis 1
y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1)
# gather wgrad in N dim, which should be axis 0
wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0)
# gather bgrad in N dim, which should be axis 0
bgrad = (
cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None
)
if sequence_parallel:
# gather dgrad in M dim, which should be axis 0
dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0)
if parallel_mode == "row":
# gather dgrad in K dim, which should be axis 1
dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1)
# gather wgrad in K dim, which should be axis 1
wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1)
if sequence_parallel:
# gather y_q in M dim, which should be axis 0
y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0)
# we need to sum bias gradient when using TP + SP
bgrad = (
cls._all_reduce_tensor(bgrad, tp_size, tp_group)
if bgrad is not None
else None
)
return y_q, dgrad, wgrad, bgrad
@classmethod
def run_linear_one_step(
cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False
):
# reset gradients
layer.zero_grad()
x.grad = None
        # Forward pass
        if isinstance(layer, te.Linear):
            # te.Linear supports FP8 weight caching via is_first_microbatch
            y_q = layer.forward(x, is_first_microbatch=is_first_microbatch)
        else:
            # plain torch.nn.Linear reference path
            y_q = layer(x)
# Backward pass
y_q.backward(gradient)
# Collect gradients
dgrad = x.grad
bgrad = (
layer._parameters["bias"].grad
if layer._parameters.get("bias", None) is not None
else None
)
assert "weight" in layer._parameters
if fuse_wgrad_accumulation:
wgrad = layer._parameters["weight"].main_grad
assert layer._parameters["weight"].grad is None
else:
wgrad = layer._parameters["weight"].grad
return y_q, dgrad, wgrad, bgrad
@classmethod
def run_linear_multiple_steps(
cls,
layer,
x,
gradient,
run_num_steps,
enable_weight_cache,
fuse_wgrad_accumulation=False,
):
"""
Run multiple steps of linear layer and collect results.
"""
y_q_list, dgrad_list, wgrad_list = [], [], []
bgrad_list = [] if layer._parameters.get("bias", None) is not None else None
for i in range(run_num_steps):
x_i = (x + i).clone().detach().requires_grad_(True)
# run_linear_one_step
            y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(
                layer,
                x_i,
                gradient,
                is_first_microbatch=(i == 0) if enable_weight_cache else None,
                fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            )
            # Collect results
            y_q_list.append(y_q.detach().clone())
            dgrad_list.append(dgrad.detach().clone())
            wgrad_list.append(wgrad.detach().clone())
            if bgrad_list is not None and bgrad is not None:
                bgrad_list.append(bgrad.detach().clone())
        # Stack the per-step results so callers can unpack them the same way
        # as the single-step path (y_q, dgrad, wgrad, bgrad).
        return (
            torch.stack(y_q_list),
            torch.stack(dgrad_list),
            torch.stack(wgrad_list),
            torch.stack(bgrad_list) if bgrad_list is not None else None,
        )
@classmethod
def run_linear(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_group=None,
tp_size=1,
rank=0,
run_num_steps=1,
enable_weight_cache=False,
fuse_wgrad_accumulation=False,
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
x = x.clone().detach().requires_grad_(True).to("cuda")
w = w.clone().detach().to("cuda")
gradient = gradient.clone().detach().to("cuda")
bias = bias.clone().detach().to("cuda") if bias is not None else None
in_features = x.shape[1]
out_features = w.shape[0]
# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
)
# set data types
params_dtype = x.dtype
# Create linear layer and copy weights
layer = te.Linear(
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
layer = layer.to("cuda")
with torch.no_grad():
layer.weight.copy_(w)
if bias is not None:
layer.bias.copy_(bias)
if fuse_wgrad_accumulation:
assert (
run_num_steps > 1
), "Fused weight gradient accumulation requires run_num_steps > 1"
layer.weight.main_grad = torch.zeros_like(layer.weight)
# Run one step or multiple steps
if run_num_steps == 1:
y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
else:
y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps(
layer,
x,
gradient,
run_num_steps,
enable_weight_cache,
fuse_wgrad_accumulation,
)
# If Model parallel: gather output and gradients from all ranks
y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
)
return y_q, dgrad, wgrad, bgrad
def compare_recipe(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed,
dtype,
y_error=0.0,
dgrad_error=0.0,
wgrad_error=0.0,
bgrad_error=0.0,
recipe1_golden_tensors=None,
recipe2_golden_tensors=None,
):
x, w, bias, gradient = self._prepare_data(
batch_size, hidden_size, out_size, use_bias, seed=seed, dtype=dtype
)
# recipe1
using_fp8_recipe = recipe1 != GetRecipes.none
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
else:
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
# recipe2
using_fp8_recipe = recipe2 != GetRecipes.none
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
else:
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
# Compare results (mean abs relative error)
        assert (
            self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error
        ), "y and y_ref have too large a mean abs relative error"
        assert (
            self._get_mean_abs_relative_error(dgrad, dgrad_ref).item() < dgrad_error
        ), "dgrad and dgrad_ref have too large a mean abs relative error"
        assert (
            self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error
        ), "wgrad and wgrad_ref have too large a mean abs relative error"
        if use_bias:
            assert (
                self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error
            ), "bgrad and bgrad_ref have too large a mean abs relative error"
        # Enforce a zero-tolerance check when a golden tensor dump is available
        if recipe2_golden_tensors is not None:
            torch.testing.assert_close(
                y_q.float(), recipe2_golden_tensors["y"].float(), atol=0.0, rtol=0.0
            )
torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0)
torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0)
if use_bias:
torch.testing.assert_close(
bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0
)
class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
@staticmethod
def _check_golden_tensor_dumps(
dump_dir, get_recipe, dims, input_dtype, use_bias, normalization
):
recipe = get_recipe()
batch_size, hidden_size, out_size = dims
fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True)
fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True)
fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)
# Expected tensor names based on the naming template
        # The dump naming template assumes per-tensor scaling.
        scaling_type = "ScalingType.PER_TENSOR"
current_seed = torch.initial_seed() # Get the current seed
expected_tensor_names = {
"y": f"y_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"ln_out": f"ln_out_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"dgrad": f"dgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"wgrad": f"wgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
"bgrad": f"bgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt",
}
if not use_bias:
expected_tensor_names.pop("bgrad")
# Check if all expected tensors are in the tensor dumps directory
tensor_map = {}
for tensor_key, tensor_name in expected_tensor_names.items():
tensor_path = dump_dir / tensor_name
if not os.path.exists(tensor_path):
print(f"Missing tensor: {tensor_name}")
return None
# Load the tensor
tensor_map[tensor_key] = torch.load(tensor_path)
return tensor_map
@classmethod
def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None):
# reset gradients
layer.zero_grad()
x.grad = None
# Forward pass
y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch)
# Backward pass
y_q.backward(gradient)
# Collect gradients
dgrad = x.grad
parameters = layer._parameters
# bias and weight gradients
bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None
assert "weight" in parameters
wgrad = parameters["weight"].grad
return y_q, ln_out, dgrad, wgrad, bgrad
@classmethod
def run_linear_multiple_steps(
cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False
):
        # No multi-step test case exists for LayerNormLinear yet.
        raise NotImplementedError("LayerNormLinear tests do not support multiple steps yet")
@classmethod
def run_layernorm_linear(
cls,
x,
w,
bias,
gradient,
parallel_mode=None,
sequence_parallel=False,
tp_group=None,
tp_size=1,
rank=0,
run_num_steps=1,
enable_weight_cache=False,
LayerNormLinearClass=te.LayerNormLinear,
normalization="LayerNorm",
):
"""
If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
the reference single GPU run.
"""
# clone inputs and move to current device
# w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
x = x.clone().detach().requires_grad_(True).to("cuda")
w = w.clone().detach().to("cuda")
gradient = gradient.clone().detach().to("cuda")
bias = bias.clone().detach().to("cuda") if bias is not None else None
in_features = x.shape[1]
out_features = w.shape[0]
# If Model parallel: split inputs for a given rank
x, w, bias, gradient = cls.run_linear_preprocess_parallel(
x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
)
# set data types
params_dtype = x.dtype
# Create linear layer and copy weights
layer = LayerNormLinearClass(
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
parallel_mode=parallel_mode,
sequence_parallel=sequence_parallel,
tp_group=tp_group,
tp_size=tp_size,
normalization=normalization,
return_layernorm_output=True,
)
layer = layer.to("cuda")
        # Copy weights and bias into the layer
with torch.no_grad():
layer.weight.copy_(w)
if bias is not None:
layer.bias.copy_(bias)
# Run one step
y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
# If Model parallel: gather output and gradients from all ranks
y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
y_q,
dgrad,
wgrad,
bgrad,
parallel_mode,
sequence_parallel,
tp_size,
tp_group,
)
return y_q, ln_out, dgrad, wgrad, bgrad
def compare_recipe(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed,
dtype,
y_error=0.0,
ln_out_error=0.0,
dgrad_error=0.0,
wgrad_error=0.0,
bgrad_error=0.0,
normalization="LayerNorm",
LayerNormLinearClass1=te.LayerNormLinear,
LayerNormLinearClass2=te.LayerNormLinear,
recipe1_golden_tensors=None,
recipe2_golden_tensors=None,
):
x, w, bias, gradient = self._prepare_data(
batch_size, hidden_size, out_size, use_bias, seed=seed, dtype=dtype
)
# recipe1
using_fp8_recipe = recipe1 != GetRecipes.none
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
x,
w,
bias,
gradient,
normalization=normalization,
LayerNormLinearClass=LayerNormLinearClass1,
)
else:
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
x,
w,
bias,
gradient,
normalization=normalization,
LayerNormLinearClass=LayerNormLinearClass1,
)
# recipe2
using_fp8_recipe = recipe2 != GetRecipes.none
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
x,
w,
bias,
gradient,
normalization=normalization,
LayerNormLinearClass=LayerNormLinearClass2,
)
else:
y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
x,
w,
bias,
gradient,
normalization=normalization,
LayerNormLinearClass=LayerNormLinearClass2,
)
# Compare results (mean abs relative error)
        assert (
            self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error
        ), "y and y_ref have too large a mean abs relative error"
        assert (
            self._get_mean_abs_relative_error(ln_out, ln_out_ref).item() < ln_out_error
        ), "ln_out and ln_out_ref have too large a mean abs relative error"
        assert (
            self._get_mean_abs_relative_error(dgrad, dgrad_ref).item() < dgrad_error
        ), "dgrad and dgrad_ref have too large a mean abs relative error"
        assert (
            self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error
        ), "wgrad and wgrad_ref have too large a mean abs relative error"
        if use_bias:
            assert (
                self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error
            ), "bgrad and bgrad_ref have too large a mean abs relative error"
        # Enforce a zero-tolerance check when a golden tensor dump is available
        if recipe2_golden_tensors is not None:
            torch.testing.assert_close(
                y_q.float(), recipe2_golden_tensors["y"].float(), atol=0.0, rtol=0.0
            )
torch.testing.assert_close(ln_out, recipe2_golden_tensors["ln_out"], atol=0.0, rtol=0.0)
torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0)
torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0)
if use_bias:
torch.testing.assert_close(
bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0
)
# FP8 per-tensor current scaling
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase):
    @classmethod
    def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default),
],
)
def test_fp8_current_scaling_with_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # Check the tensor dump directory; if it exists, load y, dgrad, wgrad,
        # and bgrad from it. If any of the tensors is missing, leave the dump
        # set to None so the zero-tolerance check is skipped.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
    @classmethod
    def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default),
],
)
def test_fp8_current_scaling_with_layernorm_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # Check the tensor dump directory; if it exists, load y, dgrad, wgrad,
        # and bgrad from it. If any of the tensors is missing, leave the dump
        # set to None so the zero-tolerance check is skipped.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR,
recipe2,
(batch_size, hidden_size, out_size),
dtype,
use_bias,
"LayerNorm",
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
ln_out_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)