Commit 0a5016b1 authored by wenjh's avatar wenjh
Browse files

Merge nv release_v2.9


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents 063ef88d 70f53666
......@@ -29,6 +29,7 @@ using namespace __hip_internal;
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else
#if !defined(__CUDACC_RTC__)
#include <cassert>
#include <cstdint>
#else
// Importing C++ standard headers is a pain with NVRTC
......
......@@ -2739,10 +2739,13 @@ def fused_attn_bwd(
assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype)
if 100 in get_all_device_compute_capability():
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+
compute_capabilities = get_all_device_compute_capability()
if any(x >= 100 for x in compute_capabilities):
assert not (
attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type,
......
......@@ -44,7 +44,6 @@ from ..quantize import (
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
should_use_rht,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
......@@ -169,16 +168,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x)
def uses_rht(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and should_use_rht(
q.scaling_mode, is_colwise=q.is_colwise
)
def has_rht_applied(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and q.has_rht_applied
# TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class
assert uses_rht(lhs_q) == uses_rht(rhs_q), (
"With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise"
" quantized as well. This is to ensure the RHT is applied to both and will cancel out in"
" the GEMM."
assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), (
"With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized"
" with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the"
" GEMM."
)
return lhs_q, rhs_q
......
......@@ -31,7 +31,7 @@ from .misc import (
from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp,
num_of_devices,
get_num_devices_in_mesh,
)
from ..quantize import (
ScaledTensor2x,
......@@ -45,7 +45,6 @@ from ..quantize import (
compute_scale_from_amax,
NoScaleTensor,
get_rht_matrix,
should_use_rht,
)
......@@ -108,17 +107,18 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
"sr_rng_state must be a uint32 array when stochastic_rounding is True but"
f" received {sr_rng_state_aval}"
)
if is_outer:
if is_outer and get_num_devices_in_mesh() > 1:
assert (
sr_rng_state_aval.shape[0] == num_of_devices()
sr_rng_state_aval.shape[0] == get_num_devices_in_mesh()
and sr_rng_state_aval.shape[1] == 4
), (
"sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
f" True and is_outer is True but received {sr_rng_state_aval.shape}"
)
else:
assert sr_rng_state_aval.shape == (4,), (
"Sharded sr_rng_state must be of shape (4,) per device when"
# We cannot assert the shape is exactly (4,) here because if the quantized data is not perfectly sharded across all devices then we will have extra rng state here. For example, this could occur when the weights are not sharded when using data parallelism. However, this is okay because the extra rng state will simply not be used and each device still has a unique rng state.
assert sr_rng_state_aval.size >= 4, (
"Sharded sr_rng_state must have at least 4 elements per device when"
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
)
......@@ -552,8 +552,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
)
# TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[3] = NamedSharding(
mesh,
PartitionSpec(tuple(x for x in x_spec if x is not None), None),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings)
out_shardings = (
out_sharding,
colwise_out_sharding,
......@@ -564,6 +569,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
if sr_rng_state.size > 4:
# See comment in abstract method for explanation of why we cannot assert exact shape
sr_rng_state = sr_rng_state.flatten()[:4]
(
local_x,
local_colwise_x,
......@@ -754,9 +762,10 @@ def _quantize_dbias_impl(
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
# fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
is_unsupported = (
quantizer.q_layout == QuantizeLayout.COLWISE
and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING
is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not (
quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and hasattr(quantizer, "use_rht")
and quantizer.use_rht
)
if is_unsupported or not PrimitiveClass.enabled():
if is_dbias:
......@@ -792,7 +801,7 @@ def _quantize_dbias_impl(
rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
amax = x.amax
if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout):
if hasattr(quantizer, "use_rht") and quantizer.use_rht:
use_rht = True
rht_matrix = get_rht_matrix()
......@@ -861,7 +870,11 @@ def _quantize_dbias_impl(
x.data,
scale,
amax,
sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32),
(
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix,
out_dtype=quantizer.q_dtype,
......@@ -902,6 +915,7 @@ def _quantize_dbias_impl(
q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
colwise_has_rht_applied=use_rht,
)
return out, dbias.astype(dq_dtype)
......
......@@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false);
return backend;
}
......@@ -179,17 +180,18 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(),
nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr);
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
......@@ -197,8 +199,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right,
kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported QKVLayout.");
......@@ -276,7 +278,8 @@ static void FusedAttnForwardImpl(
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -294,7 +297,7 @@ static void FusedAttnForwardImpl(
nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training,
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
......@@ -308,8 +311,8 @@ static void FusedAttnForwardImpl(
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, softmax_type, window_size_left, window_size_right,
q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
......@@ -323,7 +326,7 @@ static void FusedAttnForwardImpl(
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else {
......@@ -542,7 +545,8 @@ static void FusedAttnBackwardImpl(
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);
......
......@@ -15,7 +15,7 @@ import jax
import jax.numpy as jnp
from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht
from .hadamard import apply_rht
__all__ = ["ScalingModeToDequantizerMap"]
......@@ -171,7 +171,9 @@ class NVFP4Dequantizer(Dequantizer):
"""
@staticmethod
def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis):
def _dequantize_func(
data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis, has_rht_applied
):
"""Dequantize a tensor using block scaling.
Args:
......@@ -182,6 +184,7 @@ class NVFP4Dequantizer(Dequantizer):
scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D
has_rht_applied: Whether the quantization has RHT applied and we need to apply the inverse RHT to dequantize
Returns:
The dequantized tensor
......@@ -223,8 +226,7 @@ class NVFP4Dequantizer(Dequantizer):
out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)
# Apply inverse of RHT if needed
use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise)
if use_rht:
if has_rht_applied:
out = apply_rht(out, inverse=True)
return out
......@@ -247,6 +249,7 @@ class NVFP4Dequantizer(Dequantizer):
scaled_tensor.scaling_mode,
scaled_tensor.is_colwise,
scaled_tensor.flatten_axis,
scaled_tensor.has_rht_applied,
)
......
......@@ -4,32 +4,6 @@
"""Randomized Hadamard Transform (RHT) utilities for JAX."""
import jax.numpy as jnp
from .scaling_modes import ScalingMode
def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool:
"""Determine if RHT (Randomized Hadamard Transform) should be used.
Args:
scaling_mode: The scaling mode of the tensor.
is_colwise: Whether the tensor is column-wise. Only one of is_colwise or q_layout should be provided.
q_layout: The quantization layout of the tensor. Only one of is_colwise or q_layout should be provided.
Returns:
bool: True if RHT should be used, False otherwise.
"""
# Delayed import to avoid circular dependencies
from .quantizer import QuantizeLayout
assert (is_colwise is None) != (
q_layout is None
), "Exactly one of is_colwise or q_layout must be provided."
if q_layout is not None:
is_colwise = q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE}
return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise
def get_wgrad_sign_vector() -> list[int]:
"""Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization."""
......
......@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
import hashlib
from typing import Optional, Tuple, Dict, Union, Sequence, Type, List
from functools import reduce, lru_cache
import operator
......@@ -35,7 +36,7 @@ from transformer_engine.common.recipe import (
from transformer_engine.jax.sharding import (
global_shard_guard,
MeshResource,
num_of_devices,
get_num_devices_in_mesh,
get_all_mesh_axes,
with_sharding_constraint,
)
......@@ -561,29 +562,87 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig):
return QuantizeMeta()
@dataclass
class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for NVFP4 scaling recipe.
This class provides specific initialization and finalization for NVFP4 scaling quantization mode.
"""
DISABLE_STOCHASTIC_ROUNDING: bool = False
DISABLE_RHT: bool = False
DISABLE_2D_QUANTIZATION: bool = False
def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize block scaling FP8 configuration.
"""Initialize block scaling NVFP4 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
fp8_recipe: The quantization recipe to use for initialization
"""
assert isinstance(fp8_recipe, NVFP4BlockScaling)
self.INITIALIZED = True
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format)
self.AMAX_HISTORY_LEN = 0
self.DISABLE_STOCHASTIC_ROUNDING = fp8_recipe.disable_stochastic_rounding
self.DISABLE_RHT = fp8_recipe.disable_rht
self.DISABLE_2D_QUANTIZATION = fp8_recipe.disable_2d_quantization
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type."""
if tensor_source == TensorSource.KERNEL:
if (not self.DISABLE_2D_QUANTIZATION) and tensor_source == TensorSource.KERNEL:
return ScalingMode.NVFP4_2D_SCALING
# for x and grad
return ScalingMode.NVFP4_1D_SCALING
def _make_rht_quantize_meta(self, q_layout, tensor_source: TensorSource) -> QuantizeMeta:
"""Create the quantization metadata for RHT if applicable."""
# Imported here to prevent circular import
from transformer_engine.jax.quantize import QuantizeLayout
use_rht = self.get_scaling_mode(
tensor_source
) == ScalingMode.NVFP4_1D_SCALING and q_layout in {
QuantizeLayout.ROWWISE_COLWISE,
QuantizeLayout.COLWISE,
}
if self.DISABLE_RHT:
use_rht = False
return QuantizeMeta(use_rht=use_rht)
def _make_stochastic_rounding_rng_state(
self, module, tensor_source: TensorSource, quantizer_name: str
) -> jnp.ndarray:
"""Create the stochastic rounding rng state if applicable."""
if self.DISABLE_STOCHASTIC_ROUNDING:
return QuantizeMeta()
if tensor_source != TensorSource.DGRAD:
# Only DGRAD uses stochastic rounding
return QuantizeMeta()
sr_jax_rng = module.make_rng("sr_rng")
# Get a unique key for this quantizer
# Use hashlib to get a deterministic hash value for quantizer_name
quantizer_hash = (
int(hashlib.sha256(quantizer_name.encode("utf-8")).hexdigest(), 16)
% jnp.iinfo(jnp.int32).max
)
sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash)
# Generate 4 random uint32 values from the JAX PRNG key
shape = (4,)
if get_num_devices_in_mesh() > 1:
shape = (get_num_devices_in_mesh(), 4)
sr_jax_rng_state = jax.random.randint(
sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
).view(jnp.uint32)
sr_jax_rng_state = with_sharding_constraint(
sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
)
return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
def get_quantize_flax_meta(
self,
module,
......@@ -603,27 +662,14 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
"""
if tensor_source != TensorSource.DGRAD:
# Only DGRAD uses stochastic rounding
return QuantizeMeta()
# TODO(jberchtold): This assumes SR is always enabled for NVFP4. Use flag from recipe to toggle it.
sr_jax_rng = module.make_rng("sr_rng")
# Get a unique key for this quantizer
sr_jax_rng = jax.jit(jax.random.fold_in)(
sr_jax_rng, hash(quantizer_name) % jnp.iinfo(jnp.int32).max
)
# Imported here to prevent circular import
from transformer_engine.jax.quantize import QuantizeLayout
# Generate 4 random uint32 values from the JAX PRNG key
sr_jax_rng_state = jax.random.randint(
sr_jax_rng, (num_of_devices(), 4), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32
).view(jnp.uint32)
sr_jax_rng_state = with_sharding_constraint(
sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
return QuantizeMeta.merge(
self._make_rht_quantize_meta(QuantizeLayout.ROWWISE_COLWISE, tensor_source),
self._make_stochastic_rounding_rng_state(module, tensor_source, quantizer_name),
)
return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
_QUANTIZE_CONFIG = NoOpQuantizeConfig()
......
......@@ -26,6 +26,26 @@ class QuantizeMeta:
"""
@staticmethod
def merge(a: "QuantizeMeta", b: "QuantizeMeta") -> "QuantizeMeta":
"""Merge two QuantizeMeta instances.
Args:
a (QuantizeMeta): The first QuantizeMeta instance.
b (QuantizeMeta): The second QuantizeMeta instance.
Returns:
QuantizeMeta: A new QuantizeMeta instance with merged metadata.
"""
assert isinstance(a, QuantizeMeta)
assert isinstance(b, QuantizeMeta)
for key in b.get_kwargs_dictionary().keys():
if key in a.get_kwargs_dictionary():
assert (
a.get_kwargs_dictionary()[key] == b.get_kwargs_dictionary()[key]
), f"Conflict in merging QuantizeMeta: {key} has different values."
return QuantizeMeta(**{**a.get_kwargs_dictionary(), **b.get_kwargs_dictionary()})
def __init__(self, **kwargs):
self._kwargs = kwargs
......
......@@ -19,7 +19,7 @@ from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe
from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht
from .hadamard import apply_rht
from .tensor import (
ScaledTensor,
ScaledTensor1x,
......@@ -590,11 +590,13 @@ class NVFP4Quantizer(Quantizer):
q_layout: Quantization axis
data_layout: Data layout string (default: "NT")
stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled.
use_rht: Whether to apply Randomized Hadamard Transform (RHT) before quantization.
"""
scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
data_layout: str = "NT"
use_rht: bool = False
stochastic_rounding_rng_state: Optional[jnp.ndarray] = None
def __post_init__(self):
......@@ -603,6 +605,30 @@ class NVFP4Quantizer(Quantizer):
), "NVFP4 quantization must use a q_dtype of float4_e2m1fn"
assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes"
def tree_flatten(self):
"""Flatten the quantizer for JAX tree operations.
Returns:
Tuple of (children, aux_data) for tree operations
"""
children = (self.stochastic_rounding_rng_state,)
aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.use_rht)
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Reconstruct a quantizer from its flattened representation.
Args:
aux_data: Auxiliary data containing quantizer parameters
children: Unused children data
Returns:
A reconstructed Quantizer instance
"""
stochastic_rounding_rng_state = children[0]
return cls(*aux_data, stochastic_rounding_rng_state=stochastic_rounding_rng_state)
def _apply_stochastic_rounding(self, x):
assert (
self.stochastic_rounding_rng_state is not None
......@@ -688,8 +714,9 @@ class NVFP4Quantizer(Quantizer):
flatten_axis = x.ndim - flatten_axis
x_shape = x.shape
if should_use_rht(self.scaling_mode, is_colwise=is_colwise):
# We only apply RHT for 1D colwise nvfp4
# We currently only have a single flag 'use_rht' on the quantizer. To avoid an unused rowwise flag, we assume RHT is only used for colwise quantization for now.
use_rht = self.use_rht and is_colwise and self.scaling_mode == ScalingMode.NVFP4_1D_SCALING
if use_rht:
x = apply_rht(x)
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
......@@ -790,6 +817,7 @@ class NVFP4Quantizer(Quantizer):
scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype,
flatten_axis=rowwise_flatten_axis,
has_rht_applied=use_rht,
)
......
......@@ -175,6 +175,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise: Whether the tensor uses column-wise quantization
data_layout: The data_layout specification for the tensor
flatten_axis: The quantization axis for the tensor
has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization
"""
scale_inv: jnp.ndarray
......@@ -184,6 +185,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise: bool
data_layout: str
flatten_axis: int
has_rht_applied: bool
def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization.
......@@ -243,6 +245,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
self.is_colwise,
self.data_layout,
self.flatten_axis,
self.has_rht_applied,
)
return (children, aux_data)
......@@ -314,6 +317,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise=self.is_colwise,
data_layout=self.data_layout,
flatten_axis=self.flatten_axis,
has_rht_applied=self.has_rht_applied,
)
......@@ -354,6 +358,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self.group_sizes = group_sizes
self.original_shape = original_shape
self.group_axis = group_axis
# TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4
super().__init__(
data=data,
scale_inv=scale_inv,
......@@ -364,6 +369,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
has_rht_applied=False,
)
def __post_init__(self):
......@@ -515,6 +521,7 @@ class ScaledTensorFactory:
group_sizes=None,
original_shape=None,
group_axis=0,
has_rht_applied=False,
):
"""Creates a single-scale quantized tensor.
......@@ -530,6 +537,7 @@ class ScaledTensorFactory:
group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)
Returns:
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
......@@ -593,6 +601,7 @@ class ScaledTensorFactory:
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
has_rht_applied=has_rht_applied,
)
@staticmethod
......@@ -610,6 +619,8 @@ class ScaledTensorFactory:
group_sizes=None,
original_shape=None,
group_axis=0,
rowwise_has_rht_applied=False,
colwise_has_rht_applied=False,
):
"""Creates a double-scale quantized tensor.
......@@ -626,6 +637,8 @@ class ScaledTensorFactory:
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
Returns:
A ScaledTensor2x instance
......@@ -648,6 +661,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
has_rht_applied=rowwise_has_rht_applied,
)
colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data,
......@@ -661,6 +675,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
has_rht_applied=colwise_has_rht_applied,
)
return ScaledTensor2x(rowwise_tensor, colwise_tensor)
......@@ -680,6 +695,8 @@ class ScaledTensorFactory:
group_sizes: jnp.ndarray = None,
original_shape: Tuple[int] = None,
group_axis: int = 0,
rowwise_has_rht_applied: bool = False,
colwise_has_rht_applied: bool = False,
):
"""Creates a scaled tensor based on the quantization axis.
......@@ -696,10 +713,14 @@ class ScaledTensorFactory:
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
"""
assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet"
if q_layout == QuantizeLayout.ROWWISE_COLWISE:
return ScaledTensorFactory.create_2x(
data,
......@@ -715,6 +736,8 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
rowwise_has_rht_applied=rowwise_has_rht_applied,
colwise_has_rht_applied=colwise_has_rht_applied,
)
is_colwise = q_layout == QuantizeLayout.COLWISE
......@@ -731,6 +754,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
has_rht_applied=colwise_has_rht_applied,
)
return ScaledTensorFactory.create_1x(
......@@ -745,6 +769,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
has_rht_applied=rowwise_has_rht_applied,
)
......
......@@ -54,6 +54,26 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension, True)
def get_cuda_major_version() -> int:
"""Get CUDA major version using Jax backend."""
assert (
jax._src.lib.cuda_versions is not None
), "GPU backend is required to build TE jax extensions."
# Jax currently does not have any stable/public method to get cuda version.
# Try using internal function and default to cuda12 if not found.
try:
cuda_version = jax._src.lib.cuda_versions.cuda_runtime_get_version()
cuda_major_version = cuda_version // 1000
except AttributeError:
cuda_version = os.getenv("CUDA_VERSION", "12")
cuda_major_version = int(cuda_version.split(".")[0])
assert cuda_major_version in (12, 13), f"Unsupported cuda version {cuda_version}."
return cuda_major_version
if __name__ == "__main__":
"""Main entry point for JAX extension installation.
......@@ -93,15 +113,23 @@ if __name__ == "__main__":
)
]
# Setup version and requirements.
# Having the framework extension depend on the core lib allows
# us to detect CUDA version dynamically during compilation and
# choose the correct wheel for te core lib.
__version__ = te_version()
te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}"
install_requires = install_requirements() + [te_core]
# Configure package
setuptools.setup(
name="transformer_engine_jax",
version=te_version(),
version=__version__,
description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(),
install_requires=install_requires,
tests_require=test_requirements(),
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
......
......@@ -238,6 +238,19 @@ def num_of_devices():
return len(jax.devices())
def get_num_devices_in_mesh(mesh=None):
"""
Get the number of devices in the given mesh.
If the mesh is None, it would be replaced
by the global mesh.
"""
if mesh is None:
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty:
return 1
return np.prod(list(mesh.shape.values()))
def get_mesh_axis_size(axis, mesh=None):
"""
Get the axis size of the given mesh.
......
......@@ -59,6 +59,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
combine_and_quantize,
combine_and_dequantize,
print_quantizers,
ConvertTHDtoBSHD,
ConvertBSHDtoTHD,
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log,
......@@ -203,6 +205,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None,
softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
) -> None:
super().__init__()
......@@ -211,6 +214,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number
self.softmax_type = softmax_type
self.return_max_logit = return_max_logit
def mask_func(x, y):
return (
......@@ -219,6 +223,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
else attention_mask_func(x, y)
)
self.mask_func = mask_func
self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func)
# Dropout. Note that for a single iteration, this layer will generate
......@@ -240,6 +245,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
max_seqlen_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
max_seqlen_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
......@@ -263,6 +270,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
if inference_params is not None and inference_params.is_paged:
key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number)
# convert to sbhd
# training: bshd, thd
# inference: bshd, sbhd_2bshd, thd_2bshd
if qkv_format == "bshd":
# convert to sbhd and use sbhd implementation for now
query_layer, key_layer, value_layer = [
......@@ -271,9 +281,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
if qkv_format == "sbhd_2bshd":
key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]]
total_tokens, batch_size = None, None
if qkv_format == "thd_2bshd":
total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0]
batch_size = key_layer.shape[0]
query_layer = tex.convert_thd_to_bshd(
query_layer,
cu_seqlens_q,
......@@ -283,6 +292,26 @@ class UnfusedDotProductAttention(torch.nn.Module):
query_layer, key_layer, value_layer = [
x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
]
if qkv_format == "thd":
assert cu_seqlens_q is not None and cu_seqlens_kv is not None
assert max_seqlen_q is not None and max_seqlen_kv is not None
query_layer = ConvertTHDtoBSHD.apply(
query_layer,
cu_seqlens_q,
max_seqlen_q,
)
key_layer, value_layer = [
ConvertTHDtoBSHD.apply(
x,
cu_seqlens_kv,
max_seqlen_kv,
)
for x in [key_layer, value_layer]
]
query_layer, key_layer, value_layer = [
x.transpose(0, 1).contiguous() for x in [query_layer, key_layer, value_layer]
]
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[1],
query_layer.shape[0],
......@@ -428,6 +457,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
matmul_result, None, None, dP_quantizer, "dP_quantizer", None
)
# max attention score
max_logit = None
if self.return_max_logit:
# matmul_result [b, np, sq, dk], max_logit [np]
max_logit = matmul_result
if attn_mask_type != "no_mask":
max_logit = self.mask_func(matmul_result, attention_mask)
max_logit = torch.amax(max_logit, dim=(0, 2, 3))
# add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla":
matmul_result = torch.cat(
......@@ -508,14 +546,13 @@ class UnfusedDotProductAttention(torch.nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# [b, sq, np, hn] --> [tq, np, hn]
context_layer = tex.convert_bshd_to_thd(
context_layer = ConvertBSHDtoTHD.apply(
context_layer,
cu_seqlens_q,
total_tokens,
)
# [tq, np, hn] --> [tq, hp]
context_layer = context_layer.view(total_tokens, -1)
context_layer = context_layer.view(context_layer.shape[0], -1)
if fp8:
# quantize and dequantize O to emulate FP8
......@@ -531,6 +568,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
if fp8_output:
context_layer = O_quantizer(context_layer)
if self.return_max_logit:
return context_layer, max_logit
return context_layer
......@@ -1069,6 +1109,7 @@ class FusedAttnFunc(torch.autograd.Function):
softmax_offset,
fp8_output,
layer_number,
return_max_logit,
):
# pylint: disable=missing-function-docstring
......@@ -1104,6 +1145,7 @@ class FusedAttnFunc(torch.autograd.Function):
# FP8 attention: torch.float16 or torch.bfloat16
out_nominal_dtype = q.dtype
max_logit = None
if fp8:
fused_attention_backend = FusedAttnBackend["FP8"]
......@@ -1131,7 +1173,7 @@ class FusedAttnFunc(torch.autograd.Function):
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd(
out_, aux_ctx_tensors, *_ = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -1207,7 +1249,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkvo_tensors = (q, k, v, out)
else:
# q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd(
out_, aux_ctx_tensors, *max_logit = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -1235,6 +1277,7 @@ class FusedAttnFunc(torch.autograd.Function):
window_size,
rng_gen,
softmax_offset,
return_max_logit,
)
out = out_
out_ret = out_
......@@ -1329,10 +1372,12 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.use_FAv2_bwd = use_FAv2_bwd
ctx.deterministic = deterministic
if return_max_logit:
return out_ret, *max_logit
return out_ret
@staticmethod
def backward(ctx, d_out):
def backward(ctx, d_out, *_args):
# pylint: disable=missing-function-docstring
# d_out is expected to be in FP8 if is_output_fp8=True,
......@@ -1576,6 +1621,7 @@ class FusedAttnFunc(torch.autograd.Function):
d_softmax_offset,
None,
None,
None,
)
......@@ -1616,6 +1662,7 @@ class FusedAttention(torch.nn.Module):
layer_number: Optional[int] = None,
deterministic: bool = False,
softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
) -> None:
super().__init__()
......@@ -1629,6 +1676,7 @@ class FusedAttention(torch.nn.Module):
self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic
self.softmax_type = softmax_type
self.return_max_logit = return_max_logit
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
......@@ -1848,6 +1896,7 @@ class FusedAttention(torch.nn.Module):
softmax_offset=softmax_offset,
fp8_output=fp8_output,
layer_number=self.layer_number,
return_max_logit=self.return_max_logit,
)
else:
with self.attention_dropout_ctx():
......@@ -1883,7 +1932,11 @@ class FusedAttention(torch.nn.Module):
softmax_offset,
fp8_output,
self.layer_number,
self.return_max_logit,
)
if self.return_max_logit:
# ...hd -> ...(hd)
return output[0].view(*output[0].shape[:-2], -1), output[1]
# ...hd -> ...(hd)
return output.view(*output.shape[:-2], -1)
......@@ -617,6 +617,7 @@ def cp_p2p_fwd_fused_attn(
rank,
step,
cp_size,
return_max_logit,
q_part,
k_part,
v_part,
......@@ -693,7 +694,7 @@ def cp_p2p_fwd_fused_attn(
fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step
fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step
out_per_step, aux_ctx_tensors = fused_attn_fwd(
out_per_step, aux_ctx_tensors, *max_logit = fused_attn_fwd(
is_training,
max_seqlen_q_,
max_seqlen_kv_,
......@@ -713,6 +714,7 @@ def cp_p2p_fwd_fused_attn(
cu_seqlens_q_padded=cu_seqlens_q_padded_,
cu_seqlens_kv_padded=cu_seqlens_kv_padded_,
**fp8_meta_kwargs,
return_max_logit=return_max_logit,
)
if fp8:
......@@ -721,7 +723,9 @@ def cp_p2p_fwd_fused_attn(
softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors
attn_bias = rest[0] if len(rest) > 0 else None
return out_per_step, softmax_lse_per_step, rng_states, attn_bias
if return_max_logit:
return out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit
return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None
def cp_p2p_fwd_flash_attn(
......@@ -1086,6 +1090,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
attn_bias,
deterministic,
use_fused_attention,
return_max_logit,
fp8,
fp8_meta,
cp_group,
......@@ -1156,6 +1161,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
amax_per_step = None
S_quantizer_per_step = [None for _ in range(cp_size)]
O_quantizer_per_step = [None for _ in range(cp_size)]
max_logit_per_step = [None for _ in range(cp_size)]
max_logit = None
assert isinstance(k, q.__class__) and isinstance(
v, q.__class__
......@@ -1244,6 +1251,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_f16 = q
if use_fused_attention:
fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
if return_max_logit:
max_logit_per_step = [
torch.empty(q.shape[-2], dtype=q.dtype, device=q.device) for _ in range(cp_size)
]
# split qkv to two halves and prepare for load balancing
assert qkv_format == "thd" or (
......@@ -1418,6 +1429,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
rank,
i,
cp_size,
return_max_logit,
]
else:
flash_attn_inputs = [
......@@ -1462,6 +1474,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
rng_states[i],
attn_biases[i],
max_logit_per_step[i],
) = cp_p2p_fwd_fused_attn(
*fused_attn_inputs, *prepare_outputs, section
)
......@@ -1488,6 +1501,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
rng_states[i],
attn_biases[i],
max_logit_per_step[i],
) = cp_p2p_fwd_fused_attn(
*fused_attn_inputs, *prepare_outputs, section
)
......@@ -1514,6 +1528,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
rng_states[i],
attn_biases[i],
max_logit_per_step[i],
) = cp_p2p_fwd_fused_attn(
*fused_attn_inputs, *prepare_outputs, section
)
......@@ -1541,6 +1556,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i],
rng_states[i],
attn_biases[i],
max_logit_per_step[i],
) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section)
else:
out_per_step[i], softmax_lse_per_step[i], rng_states[i] = (
......@@ -1600,11 +1616,20 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse.view(*softmax_lse.shape[:-1], 2, -1),
softmax_lse_per_step[i - 1],
)
if return_max_logit:
if i == 1:
max_logit = torch.clone(max_logit_per_step[0])
else:
max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1])
if i < cp_size:
flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
torch.cuda.current_stream().wait_stream(flash_attn_streams[1])
if return_max_logit:
torch.distributed.all_reduce(
max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group
)
second_half_lse_seqlen = None
if causal and rank < (cp_size - 1):
......@@ -1682,6 +1707,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif qkv_format == "sbhd":
# [s*b, h, d] -> [s, b, h, d]
out = out.view(-1, ctx.batch_size, *out.shape[-2:])
if return_max_logit:
max_logit = flash_attn_a2a_communicate_softmax_offset(
max_logit, 0, cp_size_a2a, cp_group_a2a, cp_stream, False
)
elif not use_fused_attention:
out = out.view(-1, *out.shape[-2:])
......@@ -1811,10 +1840,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}")
if return_max_logit:
return out_ret, max_logit
return out_ret
@staticmethod
def backward(ctx, dout):
def backward(ctx, dout, *_args):
# pylint: disable=missing-function-docstring
# add NVTX range
......@@ -2522,6 +2553,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -2577,6 +2609,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
attn_bias,
deterministic,
use_fused_attention,
return_max_logit,
window_size,
cp_group,
cp_stream,
......@@ -2682,6 +2715,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
softmax_lse_per_step = [None, None]
rng_states = [None, None]
out = torch.empty_like(q)
max_logit_per_step = [None, None]
max_logit = None
for i in range(len(local_seq_chunk_ids) + 1):
if i < len(local_seq_chunk_ids):
......@@ -2712,7 +2747,11 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
# [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d]
k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]]
if use_fused_attention:
out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd(
(
out_per_step[i],
[softmax_lse_per_step[i], rng_states[i]],
*max_logit_,
) = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv_,
......@@ -2732,7 +2771,10 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
window_size=window_size_per_step[i],
return_max_logit=return_max_logit,
)
if return_max_logit:
max_logit_per_step[i] = max_logit_[0]
else:
fa_forward_args_thd = get_fa_args(
True,
......@@ -2767,14 +2809,22 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
if not use_flash_attn_3:
rng_states[i] = fa_outputs[3]
if return_max_logit and i == 0:
max_logit = torch.clone(max_logit_per_step[0])
if i > 0:
with torch.cuda.stream(flash_attn_streams[i - 1]):
if qkv_format == "bshd":
out[:, i - 1].copy_(out_per_step[i - 1])
elif qkv_format == "sbhd":
out[i - 1].copy_(out_per_step[i - 1])
if return_max_logit:
max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1])
torch.cuda.current_stream().wait_stream(cp_stream)
if return_max_logit:
torch.distributed.all_reduce(
max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group
)
if use_fused_attention:
if qkv_format == "bshd":
......@@ -2811,10 +2861,12 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
ctx.use_fused_attention = use_fused_attention
ctx.use_flash_attn_3 = use_flash_attn_3
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
if return_max_logit:
return out, max_logit
return out
@staticmethod
def backward(ctx, dout):
def backward(ctx, dout, *_args):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
cp_size = get_distributed_world_size(ctx.cp_group)
......@@ -3035,6 +3087,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -3065,6 +3118,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
attn_bias,
deterministic,
use_fused_attention,
return_max_logit,
window_size,
fp8,
fp8_meta,
......@@ -3158,6 +3212,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
fp8_recipe = fp8_meta["local_recipes"][0]
fwd_nominal_dtype = q.dtype
fused_attn_backend = None
max_logit = None
QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = (
dpa_utils.get_attention_quantizers(fp8, quantizers)
......@@ -3203,7 +3258,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype)
for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part])
]
out_, aux_ctx_tensors = fused_attn_fwd(
out_, aux_ctx_tensors, *max_logit = fused_attn_fwd(
is_training,
max_seqlen_q,
max_seqlen_kv,
......@@ -3226,6 +3281,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
**fp8_meta_kwargs,
softmax_type=softmax_type,
softmax_offset=softmax_offset,
return_max_logit=return_max_logit,
)
if isinstance(out_, Float8Tensor):
out_fp8 = out_
......@@ -3276,6 +3332,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
out_ = flash_attn_a2a_communicate(
out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False
)
if return_max_logit:
max_logit = flash_attn_a2a_communicate_softmax_offset(
*max_logit, 0, cp_size, cp_group, cp_stream, False
)
if use_fused_attention:
if qkv_format == "bshd":
......@@ -3362,10 +3422,12 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.S_quantizer = S_quantizer.copy()
ctx.S_quantizer.scale = S_quantizer.scale.clone()
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
if return_max_logit:
return out_ret, max_logit
return out_ret
@staticmethod
def backward(ctx, dout):
def backward(ctx, dout, *_args):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
cp_size = get_distributed_world_size(ctx.cp_group)
......@@ -3599,6 +3661,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
None,
None,
None,
None,
d_softmax_offset,
None,
)
......@@ -3637,6 +3700,7 @@ def attn_forward_func_with_cp(
softmax_offset=None,
fp8_output=False,
layer_number=1,
return_max_logit=False,
) -> torch.Tensor:
"""
Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
......@@ -3784,6 +3848,7 @@ def attn_forward_func_with_cp(
attn_bias,
deterministic,
use_fused_attention,
return_max_logit,
]
if cp_comm_type in ["p2p", "a2a+p2p"]:
......
......@@ -255,6 +255,12 @@ class DotProductAttention(TransformerEngineBaseModule):
where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink').
return_max_logit: Optional[bool], default = `False`
If true, returns the maximum attention score that can be used in a Muon optimizer to
rescale the Q and K projection weights (see `Muon is Scalable for LLM Training
<https://arxiv.org/pdf/2502.16982>`_).
max_logit = max(S), where S = mask(Q*K^T*softmax_scale + bias) in shape [b, h, s_q, s_kv],
and max_logit is in shape [h].
Parallelism parameters
----------------------
......@@ -311,6 +317,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
) -> None:
super().__init__()
......@@ -394,6 +401,7 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type
self.attention_dropout = attention_dropout
self.return_max_logit = return_max_logit
self.softmax_type = softmax_type
if self.softmax_type == "vanilla":
......@@ -431,6 +439,7 @@ class DotProductAttention(TransformerEngineBaseModule):
deterministic=self.deterministic,
**attn_kwargs,
softmax_type=self.softmax_type,
return_max_logit=self.return_max_logit,
)
self.unfused_attention = UnfusedDotProductAttention(
......@@ -439,6 +448,7 @@ class DotProductAttention(TransformerEngineBaseModule):
**attn_kwargs,
layer_number=layer_number,
softmax_type=self.softmax_type,
return_max_logit=self.return_max_logit,
)
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
......@@ -1303,6 +1313,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_meta=self.fp8_meta,
inference_params=inference_params,
softmax_type=self.softmax_type,
return_max_logit=self.return_max_logit,
)
global _attention_backends
if is_in_onnx_export_mode():
......@@ -1502,6 +1513,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
......@@ -1523,6 +1536,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
......
......@@ -229,6 +229,8 @@ class AttentionParams:
Inference-related parameters. See InferenceParams for details.
softmax_type: str, default = "vanilla"
The type of softmax operation. See DotProductAttention for details.
return_max_logit: bool, default = `False`
Whether to output max_logit.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
......@@ -257,6 +259,7 @@ class AttentionParams:
fp8_meta: Union[Dict[str, Any], None] = None
inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
return_max_logit: bool = False
def __eq__(self, other):
"""
......@@ -330,6 +333,7 @@ def get_attention_backend(
fp8_meta = attention_params.fp8_meta
inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
return_max_logit = attention_params.return_max_logit
# Run config
logger = logging.getLogger("DotProductAttention")
......@@ -473,14 +477,54 @@ def get_attention_backend(
if device_compute_capability < (10, 0):
logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100")
use_fused_attention = False
elif cudnn_version < (9, 14, 0):
logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
use_fused_attention = False
# TODO(cyanguwa): Modify the min cuDNN version supporting FP8 current scaling
# determinism for Blackwell
else:
if cudnn_version < (9, 14, 0):
logger.debug(
"Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0"
)
use_fused_attention = False
else:
if deterministic and cudnn_version < (9, 18, 0):
logger.debug(
"Disabling FusedAttention for FP8 current scaling requiring determinism"
" with cuDNN < 9.18.0"
)
use_fused_attention = False
# TODO: rocm fused attention backends does not support fp8 yet
if IS_HIP_EXTENSION and use_fused_attention:
logger.debug("Disabling ROCm FusedAttention as it does not support FP8")
use_fused_attention = False
if device_compute_capability == (12, 0):
if use_flash_attention:
logger.debug(
"Disabling FlashAttention as FP8 is not supported"
" for compute capability = sm120"
)
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as FP8 is not supported"
" for compute capability = sm120"
)
use_flash_attention = False
use_fused_attention = False
# Filter: Return max_logit
if return_max_logit:
if use_flash_attention:
use_flash_attention = False
logger.debug("Disabling FlashAttention for max_logit")
if use_fused_attention and qkv_format == "thd":
use_fused_attention = False
logger.debug("Disabling FusedAttention for max_logit with qkv_format = thd")
if fp8 and fp8_meta["recipe"].fp8_dpa:
use_flash_attention = False
use_fused_attention = False
use_unfused_attention = False
logger.debug("Disabling all backends for max_logit with FP8 attention")
# Filter: KV cache
# backend | precision | KV cache | architecture | qkv_format | page_size
# ---------------------------------------------------------------------------------------
......@@ -539,7 +583,7 @@ def get_attention_backend(
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
use_flash_attention_2 = False
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
if use_fused_attention and qkv_layout_group != "hd_hd_hd":
logger.debug(
......@@ -547,6 +591,19 @@ def get_attention_backend(
qkv_layout,
)
use_fused_attention = False
if (
device_compute_capability == (12, 0)
and (head_dim_qk > 128 or head_dim_qk % 8 != 0)
and is_training
):
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as MLA for backward pass is not supported for compute"
" capability = sm120 for a head_dim_qk > 128 or head_dim_qk %%8 != 0. Found:"
" head_dim_qk = %s",
head_dim_qk,
)
use_fused_attention = False
else:
if use_fused_attention and head_dim_qk != head_dim_v:
logger.debug("Disabling FusedAttention as it does not support MLA in rocm backend.")
......@@ -621,6 +678,13 @@ def get_attention_backend(
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_flash_attention = False
if device_compute_capability == (12, 0):
if use_fused_attention:
logger.debug(
"Disabling FusedAttention as qkv_format = thd is"
" not supported for compute capability = sm120"
)
use_fused_attention = False
if IS_HIP_EXTENSION and use_fused_attention and pad_between_seqs:
logger.debug(
"Disabling rocm fused attn for qkv_format = thd when there is "
......@@ -929,6 +993,7 @@ def get_attention_backend(
head_dim_v,
window_size[0],
window_size[1],
return_max_logit,
)
if fused_attention_backend == FusedAttnBackend["No_Backend"]:
logger.debug("Disabling FusedAttention as no backend supports the provided input")
......@@ -999,7 +1064,7 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
use_fused_attention = False
fused_attention_backend = None
if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0):
if is_training and device_compute_capability >= (10, 0):
logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
use_fused_attention = False
fused_attention_backend = None
......@@ -1665,6 +1730,78 @@ class UnpackTensor(torch.autograd.Function):
return None, None, _pack_tensor(indices, grad_output)
class ConvertTHDtoBSHD(torch.autograd.Function):
"""
Convert a tensor from qkv_format = thd to qkv_format = bshd.
"""
@staticmethod
def forward(ctx, thd_tensor, cu_seqlens, max_seqlen):
# pylint: disable=missing-function-docstring
batch_size = cu_seqlens.shape[0] - 1
if not thd_tensor.is_contiguous():
thd_tensor = thd_tensor.contiguous()
bshd_tensor = tex.convert_thd_to_bshd(
thd_tensor,
cu_seqlens,
batch_size,
max_seqlen,
)
ctx.save_for_backward(cu_seqlens)
ctx.num_tokens = thd_tensor.shape[0]
return bshd_tensor
@staticmethod
def backward(ctx, bshd_tensor):
# pylint: disable=missing-function-docstring
(cu_seqlens,) = ctx.saved_tensors
if not bshd_tensor.is_contiguous():
bshd_tensor = bshd_tensor.contiguous()
thd_tensor = tex.convert_bshd_to_thd(
bshd_tensor,
cu_seqlens,
ctx.num_tokens,
)
return thd_tensor, None, None
class ConvertBSHDtoTHD(torch.autograd.Function):
"""
Convert a tensor from qkv_format = bshd to qkv_format = thd.
"""
@staticmethod
def forward(ctx, bshd_tensor, cu_seqlens):
# pylint: disable=missing-function-docstring
num_tokens = cu_seqlens[-1]
max_seqlen = bshd_tensor.shape[1]
if not bshd_tensor.is_contiguous():
bshd_tensor = bshd_tensor.contiguous()
thd_tensor = tex.convert_bshd_to_thd(
bshd_tensor,
cu_seqlens,
num_tokens,
)
ctx.save_for_backward(cu_seqlens)
ctx.max_seqlen = max_seqlen
return thd_tensor
@staticmethod
def backward(ctx, thd_tensor):
# pylint: disable=missing-function-docstring
(cu_seqlens,) = ctx.saved_tensors
batch_size = cu_seqlens.shape[0] - 1
if not thd_tensor.is_contiguous():
thd_tensor = thd_tensor.contiguous()
bshd_tensor = tex.convert_thd_to_bshd(
thd_tensor,
cu_seqlens,
batch_size,
ctx.max_seqlen,
)
return bshd_tensor, None
def get_qkv_format(
qkv_layout: str = "bshd_bshd_bshd",
inference_params: InferenceParams = None,
......
......@@ -139,6 +139,7 @@ def fused_attn_fwd(
window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
return_max_logit: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
......@@ -216,6 +217,8 @@ def fused_attn_fwd(
softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
return_max_logit: bool, default = False
whether to return the maximum attention score
Returns
----------
......@@ -246,6 +249,7 @@ def fused_attn_fwd(
rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen
state of the random number generator;
[seed, offset], dtype uint64
max_logit: if return_max_logit = True, shape [h] and same data type as O; otherwise None
"""
if attn_scale is None:
......@@ -315,8 +319,22 @@ def fused_attn_fwd(
softmax_offset,
rng_gen,
rng_elts_per_thread,
return_max_logit,
)
if return_max_logit:
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
# thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
stats = output_tensors[1] + torch.log(output_tensors[2])
amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
aux_ctx_tensors.extend(output_tensors[3:])
return output_tensors[0], aux_ctx_tensors, max_logit
# out, aux_ctx_tensors
return output_tensors[0], output_tensors[1:]
......
......@@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
int64_t window_size_right, bool return_max_logit);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
......@@ -94,7 +94,7 @@ std::vector<py::object> fused_attn_fwd(
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
size_t rng_elts_per_thread, bool return_max_logit);
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
......
......@@ -45,14 +45,15 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
int64_t window_size_right, bool return_max_logit) {
#ifdef __HIP_PLATFORM_AMD__
return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
#else
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
return_max_logit);
return fused_attention_backend;
#endif
}
......@@ -110,7 +111,7 @@ std::vector<py::object> fused_attn_fwd(
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread) {
size_t rng_elts_per_thread, bool return_max_logit) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
#else
......@@ -235,8 +236,9 @@ std::vector<py::object> fused_attn_fwd(
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
});
// allocate memory for workspace and auxiliary output tensors
......@@ -256,7 +258,9 @@ std::vector<py::object> fused_attn_fwd(
};
// allocate memory for nvte_aux_tensor_pack.tensors
// f16_max512 : S [b, h, sq, skv]
// f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// f16_arbitrary:
// return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// return_max_logit=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
size_t i = 0;
at::Tensor output_tensor;
......@@ -265,8 +269,8 @@ std::vector<py::object> fused_attn_fwd(
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
set_tensor_param(i++, output_tensor);
// fp8 has an additional softmax stats tensor, ZInv
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor
if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
output_tensor =
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
......@@ -292,8 +296,9 @@ std::vector<py::object> fused_attn_fwd(
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream());
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
});
// destroy tensor wrappers, but not allocated memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment