Commit 0a5016b1 authored by wenjh
Browse files

Merge nv release_v2.9


Signed-off-by: wenjh <wenjh@sugon.com>
parents 063ef88d 70f53666
...@@ -29,6 +29,7 @@ using namespace __hip_internal; ...@@ -29,6 +29,7 @@ using namespace __hip_internal;
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2))); typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else #else
#if !defined(__CUDACC_RTC__) #if !defined(__CUDACC_RTC__)
#include <cassert>
#include <cstdint> #include <cstdint>
#else #else
// Importing C++ standard headers is a pain with NVRTC // Importing C++ standard headers is a pain with NVRTC
......
...@@ -2739,10 +2739,13 @@ def fused_attn_bwd( ...@@ -2739,10 +2739,13 @@ def fused_attn_bwd(
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
if 100 in get_all_device_compute_capability(): # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+
compute_capabilities = get_all_device_compute_capability()
if any(x >= 100 for x in compute_capabilities):
assert not ( assert not (
attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
fused_config = _FusedAttnConfig( fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
......
...@@ -44,7 +44,6 @@ from ..quantize import ( ...@@ -44,7 +44,6 @@ from ..quantize import (
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
should_use_rht,
) )
from .misc import get_padded_spec, is_all_reduce_in_float32 from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import ( from ..sharding import (
...@@ -169,16 +168,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ...@@ -169,16 +168,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x)
def uses_rht(q: AbstractBaseTensor) -> bool: def has_rht_applied(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and should_use_rht( return isinstance(q, ScaledTensor1x) and q.has_rht_applied
q.scaling_mode, is_colwise=q.is_colwise
)
# TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), (
assert uses_rht(lhs_q) == uses_rht(rhs_q), ( "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized"
"With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise" " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the"
" quantized as well. This is to ensure the RHT is applied to both and will cancel out in" " GEMM."
" the GEMM."
) )
return lhs_q, rhs_q return lhs_q, rhs_q
......
...@@ -31,7 +31,7 @@ from .misc import ( ...@@ -31,7 +31,7 @@ from .misc import (
from ..sharding import ( from ..sharding import (
all_reduce_max_along_all_axes_except_PP, all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp, all_reduce_sum_along_dp_fsdp,
num_of_devices, get_num_devices_in_mesh,
) )
from ..quantize import ( from ..quantize import (
ScaledTensor2x, ScaledTensor2x,
...@@ -45,7 +45,6 @@ from ..quantize import ( ...@@ -45,7 +45,6 @@ from ..quantize import (
compute_scale_from_amax, compute_scale_from_amax,
NoScaleTensor, NoScaleTensor,
get_rht_matrix, get_rht_matrix,
should_use_rht,
) )
...@@ -108,17 +107,18 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -108,17 +107,18 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
"sr_rng_state must be a uint32 array when stochastic_rounding is True but" "sr_rng_state must be a uint32 array when stochastic_rounding is True but"
f" received {sr_rng_state_aval}" f" received {sr_rng_state_aval}"
) )
if is_outer: if is_outer and get_num_devices_in_mesh() > 1:
assert ( assert (
sr_rng_state_aval.shape[0] == num_of_devices() sr_rng_state_aval.shape[0] == get_num_devices_in_mesh()
and sr_rng_state_aval.shape[1] == 4 and sr_rng_state_aval.shape[1] == 4
), ( ), (
"sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is" "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
f" True and is_outer is True but received {sr_rng_state_aval.shape}" f" True and is_outer is True but received {sr_rng_state_aval.shape}"
) )
else: else:
assert sr_rng_state_aval.shape == (4,), ( # We cannot assert the shape is exactly (4,) here because if the quantized data is not perfectly sharded across all devices then we will have extra rng state here. For example, this could occur when the weights are not sharded when using data parallelism. However, this is okay because the extra rng state will simply not be used and each device still has a unique rng state.
"Sharded sr_rng_state must be of shape (4,) per device when" assert sr_rng_state_aval.size >= 4, (
"Sharded sr_rng_state must have at least 4 elements per device when"
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
) )
...@@ -552,8 +552,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -552,8 +552,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
) )
# TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings[3] = NamedSharding(
mesh,
PartitionSpec(tuple(x for x in x_spec if x is not None), None),
desc="BaseDBiasQuantizePrimitive.sr_rng_state",
)
arg_shardings = tuple(arg_shardings)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -564,6 +569,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -564,6 +569,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
) )
def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix): def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
if sr_rng_state.size > 4:
# See comment in abstract method for explanation of why we cannot assert exact shape
sr_rng_state = sr_rng_state.flatten()[:4]
( (
local_x, local_x,
local_colwise_x, local_colwise_x,
...@@ -754,9 +762,10 @@ def _quantize_dbias_impl( ...@@ -754,9 +762,10 @@ def _quantize_dbias_impl(
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
# fall back on the native-JAX quantize implementation # fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
is_unsupported = ( is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not (
quantizer.q_layout == QuantizeLayout.COLWISE quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING and hasattr(quantizer, "use_rht")
and quantizer.use_rht
) )
if is_unsupported or not PrimitiveClass.enabled(): if is_unsupported or not PrimitiveClass.enabled():
if is_dbias: if is_dbias:
...@@ -792,7 +801,7 @@ def _quantize_dbias_impl( ...@@ -792,7 +801,7 @@ def _quantize_dbias_impl(
rht_matrix = jnp.empty((1, 1), jnp.bfloat16) rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
amax = x.amax amax = x.amax
if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout): if hasattr(quantizer, "use_rht") and quantizer.use_rht:
use_rht = True use_rht = True
rht_matrix = get_rht_matrix() rht_matrix = get_rht_matrix()
...@@ -861,7 +870,11 @@ def _quantize_dbias_impl( ...@@ -861,7 +870,11 @@ def _quantize_dbias_impl(
x.data, x.data,
scale, scale,
amax, amax,
sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32), (
sr_rng_state
if sr_rng_state is not None
else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
),
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32), post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix, rht_matrix,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
...@@ -902,6 +915,7 @@ def _quantize_dbias_impl( ...@@ -902,6 +915,7 @@ def _quantize_dbias_impl(
q_layout=quantizer.q_layout, q_layout=quantizer.q_layout,
data_layout=quantizer.get_data_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
colwise_has_rht_applied=use_rht,
) )
return out, dbias.astype(dq_dtype) return out, dbias.astype(dq_dtype)
......
...@@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy ...@@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false);
return backend; return backend;
} }
...@@ -179,17 +180,18 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -179,17 +180,18 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(),
nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_right, query_workspace_tensor.data(), nullptr); window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd( nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
...@@ -197,8 +199,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -197,8 +199,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout,
mask_type, softmax_type, window_size_left, window_size_right, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr); query_workspace_tensor.data(), nullptr);
} else { } else {
NVTE_ERROR("Unsupported QKVLayout."); NVTE_ERROR("Unsupported QKVLayout.");
...@@ -276,7 +278,8 @@ static void FusedAttnForwardImpl( ...@@ -276,7 +278,8 @@ static void FusedAttnForwardImpl(
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */ /* Auxiliary tensors (to be propagated to the backward pass later) */
...@@ -294,7 +297,7 @@ static void FusedAttnForwardImpl( ...@@ -294,7 +297,7 @@ static void FusedAttnForwardImpl(
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream); window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
...@@ -308,8 +311,8 @@ static void FusedAttnForwardImpl( ...@@ -308,8 +311,8 @@ static void FusedAttnForwardImpl(
s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability,
bias_type, mask_type, softmax_type, window_size_left, window_size_right, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
workspace_tensor.data(), stream); workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
...@@ -323,7 +326,7 @@ static void FusedAttnForwardImpl( ...@@ -323,7 +326,7 @@ static void FusedAttnForwardImpl(
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, workspace_tensor.data(), stream); window_size_right, workspace_tensor.data(), stream);
} else { } else {
...@@ -542,7 +545,8 @@ static void FusedAttnBackwardImpl( ...@@ -542,7 +545,8 @@ static void FusedAttnBackwardImpl(
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias); softmax_aux, rng_state, bias);
......
...@@ -15,7 +15,7 @@ import jax ...@@ -15,7 +15,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht from .hadamard import apply_rht
__all__ = ["ScalingModeToDequantizerMap"] __all__ = ["ScalingModeToDequantizerMap"]
...@@ -171,7 +171,9 @@ class NVFP4Dequantizer(Dequantizer): ...@@ -171,7 +171,9 @@ class NVFP4Dequantizer(Dequantizer):
""" """
@staticmethod @staticmethod
def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis): def _dequantize_func(
data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis, has_rht_applied
):
"""Dequantize a tensor using block scaling. """Dequantize a tensor using block scaling.
Args: Args:
...@@ -182,6 +184,7 @@ class NVFP4Dequantizer(Dequantizer): ...@@ -182,6 +184,7 @@ class NVFP4Dequantizer(Dequantizer):
scaling_mode: The scaling mode used for quantization scaling_mode: The scaling mode used for quantization
is_colwise: Whether the scaling is column-wise is_colwise: Whether the scaling is column-wise
flatten_axis: The axis along which the tensor could be flattened to 2D flatten_axis: The axis along which the tensor could be flattened to 2D
has_rht_applied: Whether the quantization has RHT applied and we need to apply the inverse RHT to dequantize
Returns: Returns:
The dequantized tensor The dequantized tensor
...@@ -223,8 +226,7 @@ class NVFP4Dequantizer(Dequantizer): ...@@ -223,8 +226,7 @@ class NVFP4Dequantizer(Dequantizer):
out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape) out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)
# Apply inverse of RHT if needed # Apply inverse of RHT if needed
use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise) if has_rht_applied:
if use_rht:
out = apply_rht(out, inverse=True) out = apply_rht(out, inverse=True)
return out return out
...@@ -247,6 +249,7 @@ class NVFP4Dequantizer(Dequantizer): ...@@ -247,6 +249,7 @@ class NVFP4Dequantizer(Dequantizer):
scaled_tensor.scaling_mode, scaled_tensor.scaling_mode,
scaled_tensor.is_colwise, scaled_tensor.is_colwise,
scaled_tensor.flatten_axis, scaled_tensor.flatten_axis,
scaled_tensor.has_rht_applied,
) )
......
...@@ -4,32 +4,6 @@ ...@@ -4,32 +4,6 @@
"""Randomized Hadamard Transform (RHT) utilities for JAX.""" """Randomized Hadamard Transform (RHT) utilities for JAX."""
import jax.numpy as jnp import jax.numpy as jnp
from .scaling_modes import ScalingMode
def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool:
    """Decide whether the Randomized Hadamard Transform (RHT) applies to a tensor.

    RHT is reported only for the NVFP4 1D scaling mode combined with a
    column-wise tensor. The column-wise property is either passed directly
    through ``is_colwise`` or derived from ``q_layout``.

    Args:
        scaling_mode: The scaling mode of the tensor.
        is_colwise: Whether the tensor is column-wise. Mutually exclusive with ``q_layout``.
        q_layout: The quantization layout of the tensor. Mutually exclusive with ``is_colwise``.

    Returns:
        bool: True if RHT should be used, False otherwise.
    """
    # Imported lazily to avoid a circular dependency with the quantizer module.
    from .quantizer import QuantizeLayout

    colwise_missing = is_colwise is None
    layout_missing = q_layout is None
    assert colwise_missing != layout_missing, "Exactly one of is_colwise or q_layout must be provided."

    if not layout_missing:
        # Both layouts below produce column-wise quantized data.
        colwise_layouts = {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE}
        is_colwise = q_layout in colwise_layouts

    return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise
def get_wgrad_sign_vector() -> list[int]: def get_wgrad_sign_vector() -> list[int]:
"""Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization.""" """Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization."""
......
...@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod ...@@ -12,6 +12,7 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
import hashlib
from typing import Optional, Tuple, Dict, Union, Sequence, Type, List from typing import Optional, Tuple, Dict, Union, Sequence, Type, List
from functools import reduce, lru_cache from functools import reduce, lru_cache
import operator import operator
...@@ -35,7 +36,7 @@ from transformer_engine.common.recipe import ( ...@@ -35,7 +36,7 @@ from transformer_engine.common.recipe import (
from transformer_engine.jax.sharding import ( from transformer_engine.jax.sharding import (
global_shard_guard, global_shard_guard,
MeshResource, MeshResource,
num_of_devices, get_num_devices_in_mesh,
get_all_mesh_axes, get_all_mesh_axes,
with_sharding_constraint, with_sharding_constraint,
) )
...@@ -561,29 +562,87 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -561,29 +562,87 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig):
return QuantizeMeta() return QuantizeMeta()
@dataclass
class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
"""Configuration class for NVFP4 scaling recipe. """Configuration class for NVFP4 scaling recipe.
This class provides specific initialization and finalization for NVFP4 scaling quantization mode. This class provides specific initialization and finalization for NVFP4 scaling quantization mode.
""" """
DISABLE_STOCHASTIC_ROUNDING: bool = False
DISABLE_RHT: bool = False
DISABLE_2D_QUANTIZATION: bool = False
def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: def initialize_from_recipe(self, fp8_recipe: Recipe) -> None:
"""Initialize block scaling FP8 configuration. """Initialize block scaling NVFP4 configuration.
Args: Args:
fp8_recipe: The FP8 recipe to use for initialization fp8_recipe: The quantization recipe to use for initialization
""" """
assert isinstance(fp8_recipe, NVFP4BlockScaling)
self.INITIALIZED = True self.INITIALIZED = True
self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format) self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format)
self.AMAX_HISTORY_LEN = 0 self.AMAX_HISTORY_LEN = 0
self.DISABLE_STOCHASTIC_ROUNDING = fp8_recipe.disable_stochastic_rounding
self.DISABLE_RHT = fp8_recipe.disable_rht
self.DISABLE_2D_QUANTIZATION = fp8_recipe.disable_2d_quantization
def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode:
"""Gets the scaling mode for a specific tensor's usage type.""" """Gets the scaling mode for a specific tensor's usage type."""
if tensor_source == TensorSource.KERNEL: if (not self.DISABLE_2D_QUANTIZATION) and tensor_source == TensorSource.KERNEL:
return ScalingMode.NVFP4_2D_SCALING return ScalingMode.NVFP4_2D_SCALING
# for x and grad # for x and grad
return ScalingMode.NVFP4_1D_SCALING return ScalingMode.NVFP4_1D_SCALING
def _make_rht_quantize_meta(self, q_layout, tensor_source: TensorSource) -> QuantizeMeta:
    """Build the quantization metadata carrying the ``use_rht`` flag.

    RHT is requested only when the scaling mode selected for ``tensor_source``
    is NVFP4 1D scaling and ``q_layout`` includes column-wise data. The
    ``DISABLE_RHT`` setting overrides the result and forces the flag off.

    Args:
        q_layout: The quantization layout to evaluate.
        tensor_source: The usage type of the tensor being quantized.

    Returns:
        A QuantizeMeta holding a single ``use_rht`` entry.
    """
    # Imported here to prevent circular import
    from transformer_engine.jax.quantize import QuantizeLayout

    layouts_with_colwise = {
        QuantizeLayout.ROWWISE_COLWISE,
        QuantizeLayout.COLWISE,
    }
    is_1d_scaling = self.get_scaling_mode(tensor_source) == ScalingMode.NVFP4_1D_SCALING
    use_rht = is_1d_scaling and q_layout in layouts_with_colwise

    # The recipe-level switch always wins over the layout-based decision.
    return QuantizeMeta(use_rht=False if self.DISABLE_RHT else use_rht)
def _make_stochastic_rounding_rng_state(
    self, module, tensor_source: TensorSource, quantizer_name: str
) -> QuantizeMeta:
    """Create the stochastic-rounding RNG state metadata if applicable.

    Args:
        module: Object providing ``make_rng``; the ``"sr_rng"`` stream is drawn from it.
        tensor_source: The usage type of the tensor being quantized.
        quantizer_name: Name of the quantizer, folded into the RNG key so each
            quantizer gets its own random stream.

    Returns:
        An empty QuantizeMeta when stochastic rounding is disabled or the tensor
        is not a DGRAD tensor. Otherwise a QuantizeMeta whose
        ``stochastic_rounding_rng_state`` entry is a uint32 array of shape
        ``(4,)`` when the mesh reports at most one device, or
        ``(num_devices, 4)`` otherwise.
    """
    # Only DGRAD uses stochastic rounding, and the recipe can switch it off entirely.
    if self.DISABLE_STOCHASTIC_ROUNDING or tensor_source != TensorSource.DGRAD:
        return QuantizeMeta()

    sr_jax_rng = module.make_rng("sr_rng")

    # Derive a per-quantizer key. hashlib is used so the value is deterministic for a
    # given quantizer_name (Python's built-in str hash may be salted per process).
    int32_max = jnp.iinfo(jnp.int32).max
    name_digest = hashlib.sha256(quantizer_name.encode("utf-8")).hexdigest()
    quantizer_hash = int(name_digest, 16) % int32_max
    sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash)

    # Generate 4 random values per device; a leading device axis is added only
    # when the mesh has more than one device.
    num_devices = get_num_devices_in_mesh()
    shape = (num_devices, 4) if num_devices > 1 else (4,)

    # Values are drawn as int32 and then reinterpreted as uint32.
    sr_jax_rng_state = jax.random.randint(
        sr_jax_rng, shape, 0, int32_max, dtype=jnp.int32
    ).view(jnp.uint32)
    sr_jax_rng_state = with_sharding_constraint(
        sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
    )

    return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
def get_quantize_flax_meta( def get_quantize_flax_meta(
self, self,
module, module,
...@@ -603,27 +662,14 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): ...@@ -603,27 +662,14 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig):
Returns: Returns:
The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed.
""" """
if tensor_source != TensorSource.DGRAD: # Imported here to prevent circular import
# Only DGRAD uses stochastic rounding from transformer_engine.jax.quantize import QuantizeLayout
return QuantizeMeta()
# TODO(jberchtold): This assumes SR is always enabled for NVFP4. Use flag from recipe to toggle it.
sr_jax_rng = module.make_rng("sr_rng")
# Get a unique key for this quantizer
sr_jax_rng = jax.jit(jax.random.fold_in)(
sr_jax_rng, hash(quantizer_name) % jnp.iinfo(jnp.int32).max
)
# Generate 4 random uint32 values from the JAX PRNG key return QuantizeMeta.merge(
sr_jax_rng_state = jax.random.randint( self._make_rht_quantize_meta(QuantizeLayout.ROWWISE_COLWISE, tensor_source),
sr_jax_rng, (num_of_devices(), 4), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32 self._make_stochastic_rounding_rng_state(module, tensor_source, quantizer_name),
).view(jnp.uint32)
sr_jax_rng_state = with_sharding_constraint(
sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None)
) )
return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state)
_QUANTIZE_CONFIG = NoOpQuantizeConfig() _QUANTIZE_CONFIG = NoOpQuantizeConfig()
......
...@@ -26,6 +26,26 @@ class QuantizeMeta: ...@@ -26,6 +26,26 @@ class QuantizeMeta:
""" """
@staticmethod
def merge(a: "QuantizeMeta", b: "QuantizeMeta") -> "QuantizeMeta":
    """Combine the metadata of two QuantizeMeta instances into a new one.

    A key present in both inputs must hold equal values; otherwise an
    assertion fails.

    Args:
        a (QuantizeMeta): The first QuantizeMeta instance.
        b (QuantizeMeta): The second QuantizeMeta instance.

    Returns:
        QuantizeMeta: A new QuantizeMeta instance with merged metadata.
    """
    assert isinstance(a, QuantizeMeta)
    assert isinstance(b, QuantizeMeta)

    first_kwargs = a.get_kwargs_dictionary()
    second_kwargs = b.get_kwargs_dictionary()

    # Shared keys are only allowed when both sides agree on the value.
    for key, second_value in second_kwargs.items():
        if key in first_kwargs:
            assert (
                first_kwargs[key] == second_value
            ), f"Conflict in merging QuantizeMeta: {key} has different values."

    merged_kwargs = dict(first_kwargs)
    merged_kwargs.update(second_kwargs)
    return QuantizeMeta(**merged_kwargs)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._kwargs = kwargs self._kwargs = kwargs
......
...@@ -19,7 +19,7 @@ from transformer_engine_jax import QuantizeLayout ...@@ -19,7 +19,7 @@ from transformer_engine_jax import QuantizeLayout
from transformer_engine.common import recipe from transformer_engine.common import recipe
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht from .hadamard import apply_rht
from .tensor import ( from .tensor import (
ScaledTensor, ScaledTensor,
ScaledTensor1x, ScaledTensor1x,
...@@ -590,11 +590,13 @@ class NVFP4Quantizer(Quantizer): ...@@ -590,11 +590,13 @@ class NVFP4Quantizer(Quantizer):
q_layout: Quantization axis q_layout: Quantization axis
data_layout: Data layout string (default: "NT") data_layout: Data layout string (default: "NT")
stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled. stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled.
use_rht: Whether to apply Randomized Hadamard Transform (RHT) before quantization.
""" """
scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
data_layout: str = "NT" data_layout: str = "NT"
use_rht: bool = False
stochastic_rounding_rng_state: Optional[jnp.ndarray] = None stochastic_rounding_rng_state: Optional[jnp.ndarray] = None
def __post_init__(self): def __post_init__(self):
...@@ -603,6 +605,30 @@ class NVFP4Quantizer(Quantizer): ...@@ -603,6 +605,30 @@ class NVFP4Quantizer(Quantizer):
), "NVFP4 quantization must use a q_dtype of float4_e2m1fn" ), "NVFP4 quantization must use a q_dtype of float4_e2m1fn"
assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes" assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes"
def tree_flatten(self):
    """Split the quantizer into tree children and static auxiliary data.

    The stochastic rounding RNG state is the only child; all remaining
    fields are carried as auxiliary data. The auxiliary tuple is ordered as
    (q_dtype, scaling_mode, q_layout, data_layout, use_rht) because
    tree_unflatten forwards it positionally to the constructor.

    Returns:
        Tuple of (children, aux_data) for tree operations
    """
    leaves = (self.stochastic_rounding_rng_state,)
    static_fields = (
        self.q_dtype,
        self.scaling_mode,
        self.q_layout,
        self.data_layout,
        self.use_rht,
    )
    return leaves, static_fields
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Rebuild a quantizer from its flattened representation.

    Args:
        aux_data: Auxiliary data containing quantizer parameters; unpacked
            positionally into the constructor
        children: Children data; the first element is used as the
            stochastic rounding RNG state

    Returns:
        A reconstructed Quantizer instance
    """
    # Read the child first, then forward everything to the constructor.
    keyword_fields = {"stochastic_rounding_rng_state": children[0]}
    return cls(*aux_data, **keyword_fields)
def _apply_stochastic_rounding(self, x): def _apply_stochastic_rounding(self, x):
assert ( assert (
self.stochastic_rounding_rng_state is not None self.stochastic_rounding_rng_state is not None
...@@ -688,8 +714,9 @@ class NVFP4Quantizer(Quantizer): ...@@ -688,8 +714,9 @@ class NVFP4Quantizer(Quantizer):
flatten_axis = x.ndim - flatten_axis flatten_axis = x.ndim - flatten_axis
x_shape = x.shape x_shape = x.shape
if should_use_rht(self.scaling_mode, is_colwise=is_colwise): # We currently only have a single flag 'use_rht' on the quantizer. To avoid an unused rowwise flag, we assume RHT is only used for colwise quantization for now.
# We only apply RHT for 1D colwise nvfp4 use_rht = self.use_rht and is_colwise and self.scaling_mode == ScalingMode.NVFP4_1D_SCALING
if use_rht:
x = apply_rht(x) x = apply_rht(x)
dq_dtype = dq_dtype if dq_dtype is not None else x.dtype dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
...@@ -790,6 +817,7 @@ class NVFP4Quantizer(Quantizer): ...@@ -790,6 +817,7 @@ class NVFP4Quantizer(Quantizer):
scaling_mode=self.scaling_mode, scaling_mode=self.scaling_mode,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=rowwise_flatten_axis, flatten_axis=rowwise_flatten_axis,
has_rht_applied=use_rht,
) )
......
...@@ -175,6 +175,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -175,6 +175,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise: Whether the tensor uses column-wise quantization is_colwise: Whether the tensor uses column-wise quantization
data_layout: The data_layout specification for the tensor data_layout: The data_layout specification for the tensor
flatten_axis: The quantization axis for the tensor flatten_axis: The quantization axis for the tensor
has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization
""" """
scale_inv: jnp.ndarray scale_inv: jnp.ndarray
...@@ -184,6 +185,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -184,6 +185,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise: bool is_colwise: bool
data_layout: str data_layout: str
flatten_axis: int flatten_axis: int
has_rht_applied: bool
def __post_init__(self): def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization. """Validates and adjusts the scale_inv shape after initialization.
...@@ -243,6 +245,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -243,6 +245,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
self.is_colwise, self.is_colwise,
self.data_layout, self.data_layout,
self.flatten_axis, self.flatten_axis,
self.has_rht_applied,
) )
return (children, aux_data) return (children, aux_data)
...@@ -314,6 +317,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): ...@@ -314,6 +317,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
is_colwise=self.is_colwise, is_colwise=self.is_colwise,
data_layout=self.data_layout, data_layout=self.data_layout,
flatten_axis=self.flatten_axis, flatten_axis=self.flatten_axis,
has_rht_applied=self.has_rht_applied,
) )
...@@ -354,6 +358,7 @@ class GroupedScaledTensor1x(ScaledTensor1x): ...@@ -354,6 +358,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
self.group_sizes = group_sizes self.group_sizes = group_sizes
self.original_shape = original_shape self.original_shape = original_shape
self.group_axis = group_axis self.group_axis = group_axis
# TODO(Phuong): Handle RHT for grouped quantization once grouped quantization supports NVFP4
super().__init__( super().__init__(
data=data, data=data,
scale_inv=scale_inv, scale_inv=scale_inv,
...@@ -364,6 +369,7 @@ class GroupedScaledTensor1x(ScaledTensor1x): ...@@ -364,6 +369,7 @@ class GroupedScaledTensor1x(ScaledTensor1x):
is_colwise=is_colwise, is_colwise=is_colwise,
data_layout=data_layout, data_layout=data_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
has_rht_applied=False,
) )
def __post_init__(self): def __post_init__(self):
...@@ -515,6 +521,7 @@ class ScaledTensorFactory: ...@@ -515,6 +521,7 @@ class ScaledTensorFactory:
group_sizes=None, group_sizes=None,
original_shape=None, original_shape=None,
group_axis=0, group_axis=0,
has_rht_applied=False,
): ):
"""Creates a single-scale quantized tensor. """Creates a single-scale quantized tensor.
...@@ -530,6 +537,7 @@ class ScaledTensorFactory: ...@@ -530,6 +537,7 @@ class ScaledTensorFactory:
group_sizes: Array of ints containing the size of each group (default: None) group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None) original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0) group_axis: The axis along which grouping is performed (default: 0)
has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)
Returns: Returns:
A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
...@@ -593,6 +601,7 @@ class ScaledTensorFactory: ...@@ -593,6 +601,7 @@ class ScaledTensorFactory:
is_colwise=is_colwise, is_colwise=is_colwise,
data_layout=data_layout, data_layout=data_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
has_rht_applied=has_rht_applied,
) )
@staticmethod @staticmethod
...@@ -610,6 +619,8 @@ class ScaledTensorFactory: ...@@ -610,6 +619,8 @@ class ScaledTensorFactory:
group_sizes=None, group_sizes=None,
original_shape=None, original_shape=None,
group_axis=0, group_axis=0,
rowwise_has_rht_applied=False,
colwise_has_rht_applied=False,
): ):
"""Creates a double-scale quantized tensor. """Creates a double-scale quantized tensor.
...@@ -626,6 +637,8 @@ class ScaledTensorFactory: ...@@ -626,6 +637,8 @@ class ScaledTensorFactory:
group_sizes: Array containing the size of each group (default: None) group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None) original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0) group_axis: The axis along which grouping is performed (default: 0)
rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
Returns: Returns:
A ScaledTensor2x instance A ScaledTensor2x instance
...@@ -648,6 +661,7 @@ class ScaledTensorFactory: ...@@ -648,6 +661,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
has_rht_applied=rowwise_has_rht_applied,
) )
colwise_tensor = ScaledTensorFactory.create_1x( colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data, colwise_data,
...@@ -661,6 +675,7 @@ class ScaledTensorFactory: ...@@ -661,6 +675,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
has_rht_applied=colwise_has_rht_applied,
) )
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
...@@ -680,6 +695,8 @@ class ScaledTensorFactory: ...@@ -680,6 +695,8 @@ class ScaledTensorFactory:
group_sizes: jnp.ndarray = None, group_sizes: jnp.ndarray = None,
original_shape: Tuple[int] = None, original_shape: Tuple[int] = None,
group_axis: int = 0, group_axis: int = 0,
rowwise_has_rht_applied: bool = False,
colwise_has_rht_applied: bool = False,
): ):
"""Creates a scaled tensor based on the quantization axis. """Creates a scaled tensor based on the quantization axis.
...@@ -696,10 +713,14 @@ class ScaledTensorFactory: ...@@ -696,10 +713,14 @@ class ScaledTensorFactory:
group_sizes: Array containing the size of each group (default: None) group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None) original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0) group_axis: The axis along which grouping is performed (default: 0)
rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
Returns: Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
""" """
assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet"
if q_layout == QuantizeLayout.ROWWISE_COLWISE: if q_layout == QuantizeLayout.ROWWISE_COLWISE:
return ScaledTensorFactory.create_2x( return ScaledTensorFactory.create_2x(
data, data,
...@@ -715,6 +736,8 @@ class ScaledTensorFactory: ...@@ -715,6 +736,8 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
rowwise_has_rht_applied=rowwise_has_rht_applied,
colwise_has_rht_applied=colwise_has_rht_applied,
) )
is_colwise = q_layout == QuantizeLayout.COLWISE is_colwise = q_layout == QuantizeLayout.COLWISE
...@@ -731,6 +754,7 @@ class ScaledTensorFactory: ...@@ -731,6 +754,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
has_rht_applied=colwise_has_rht_applied,
) )
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
...@@ -745,6 +769,7 @@ class ScaledTensorFactory: ...@@ -745,6 +769,7 @@ class ScaledTensorFactory:
group_sizes=group_sizes, group_sizes=group_sizes,
original_shape=original_shape, original_shape=original_shape,
group_axis=group_axis, group_axis=group_axis,
has_rht_applied=rowwise_has_rht_applied,
) )
......
...@@ -54,6 +54,26 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1" ...@@ -54,6 +54,26 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension, True) CMakeBuildExtension = get_build_ext(BuildExtension, True)
def get_cuda_major_version() -> int:
    """Determine the CUDA major version through the JAX GPU backend.

    Returns:
        The CUDA major version; only 12 and 13 are accepted.

    Raises:
        AssertionError: If ``jax._src.lib.cuda_versions`` is None, or if the
            detected major version is neither 12 nor 13.
    """
    cuda_versions = jax._src.lib.cuda_versions
    assert (
        cuda_versions is not None
    ), "GPU backend is required to build TE jax extensions."

    # JAX has no stable/public API for the CUDA version, so an internal helper
    # is queried. If that helper is missing, fall back to the CUDA_VERSION
    # environment variable, defaulting to CUDA 12.
    try:
        cuda_version = cuda_versions.cuda_runtime_get_version()
        major = cuda_version // 1000
    except AttributeError:
        cuda_version = os.getenv("CUDA_VERSION", "12")
        major = int(cuda_version.split(".")[0])

    assert major in (12, 13), f"Unsupported cuda version {cuda_version}."
    return major
if __name__ == "__main__": if __name__ == "__main__":
"""Main entry point for JAX extension installation. """Main entry point for JAX extension installation.
...@@ -93,15 +113,23 @@ if __name__ == "__main__": ...@@ -93,15 +113,23 @@ if __name__ == "__main__":
) )
] ]
# Setup version and requirements.
# Having the framework extension depend on the core lib allows
# us to detect CUDA version dynamically during compilation and
# choose the correct wheel for te core lib.
__version__ = te_version()
te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}"
install_requires = install_requirements() + [te_core]
# Configure package # Configure package
setuptools.setup( setuptools.setup(
name="transformer_engine_jax", name="transformer_engine_jax",
version=te_version(), version=__version__,
description="Transformer acceleration library - Jax Lib", description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
python_requires=f">={min_python_version_str()}", python_requires=f">={min_python_version_str()}",
install_requires=install_requirements(), install_requires=install_requires,
tests_require=test_requirements(), tests_require=test_requirements(),
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
......
...@@ -238,6 +238,19 @@ def num_of_devices(): ...@@ -238,6 +238,19 @@ def num_of_devices():
return len(jax.devices()) return len(jax.devices())
def get_num_devices_in_mesh(mesh=None):
    """
    Get the number of devices in the given mesh.
    If the mesh is None, the global mesh (the physical mesh of the
    current thread resources) is used instead.

    Returns:
        The product of all mesh axis sizes as a Python int, or 1 if the
        mesh is empty.
    """
    if mesh is None:
        mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    if mesh.empty:
        return 1
    # np.prod yields a NumPy scalar; convert it so both branches return a
    # plain Python int.
    return int(np.prod(list(mesh.shape.values())))
def get_mesh_axis_size(axis, mesh=None): def get_mesh_axis_size(axis, mesh=None):
""" """
Get the axis size of the given mesh. Get the axis size of the given mesh.
......
...@@ -59,6 +59,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import ( ...@@ -59,6 +59,8 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
combine_and_quantize, combine_and_quantize,
combine_and_dequantize, combine_and_dequantize,
print_quantizers, print_quantizers,
ConvertTHDtoBSHD,
ConvertBSHDtoTHD,
) )
from transformer_engine.pytorch.attention.dot_product_attention.utils import ( from transformer_engine.pytorch.attention.dot_product_attention.utils import (
AttentionLogging as attn_log, AttentionLogging as attn_log,
...@@ -203,6 +205,7 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -203,6 +205,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_dropout_ctx: Optional[Callable] = nullcontext, attention_dropout_ctx: Optional[Callable] = nullcontext,
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
softmax_type: str = "vanilla", softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -211,6 +214,7 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -211,6 +214,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
self.attention_dropout_ctx = attention_dropout_ctx self.attention_dropout_ctx = attention_dropout_ctx
self.layer_number = layer_number self.layer_number = layer_number
self.softmax_type = softmax_type self.softmax_type = softmax_type
self.return_max_logit = return_max_logit
def mask_func(x, y): def mask_func(x, y):
return ( return (
...@@ -219,6 +223,7 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -219,6 +223,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
else attention_mask_func(x, y) else attention_mask_func(x, y)
) )
self.mask_func = mask_func
self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func) self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func)
# Dropout. Note that for a single iteration, this layer will generate # Dropout. Note that for a single iteration, this layer will generate
...@@ -240,6 +245,8 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -240,6 +245,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
qkv_layout: str = "sbh3d", qkv_layout: str = "sbh3d",
cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
max_seqlen_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
max_seqlen_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
...@@ -263,6 +270,9 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -263,6 +270,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
if inference_params is not None and inference_params.is_paged: if inference_params is not None and inference_params.is_paged:
key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number) key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number)
# convert to sbhd
# training: bshd, thd
# inference: bshd, sbhd_2bshd, thd_2bshd
if qkv_format == "bshd": if qkv_format == "bshd":
# convert to sbhd and use sbhd implementation for now # convert to sbhd and use sbhd implementation for now
query_layer, key_layer, value_layer = [ query_layer, key_layer, value_layer = [
...@@ -271,9 +281,8 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -271,9 +281,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
if qkv_format == "sbhd_2bshd": if qkv_format == "sbhd_2bshd":
key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]] key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]]
total_tokens, batch_size = None, None
if qkv_format == "thd_2bshd": if qkv_format == "thd_2bshd":
total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0] batch_size = key_layer.shape[0]
query_layer = tex.convert_thd_to_bshd( query_layer = tex.convert_thd_to_bshd(
query_layer, query_layer,
cu_seqlens_q, cu_seqlens_q,
...@@ -283,6 +292,26 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -283,6 +292,26 @@ class UnfusedDotProductAttention(torch.nn.Module):
query_layer, key_layer, value_layer = [ query_layer, key_layer, value_layer = [
x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
] ]
if qkv_format == "thd":
assert cu_seqlens_q is not None and cu_seqlens_kv is not None
assert max_seqlen_q is not None and max_seqlen_kv is not None
query_layer = ConvertTHDtoBSHD.apply(
query_layer,
cu_seqlens_q,
max_seqlen_q,
)
key_layer, value_layer = [
ConvertTHDtoBSHD.apply(
x,
cu_seqlens_kv,
max_seqlen_kv,
)
for x in [key_layer, value_layer]
]
query_layer, key_layer, value_layer = [
x.transpose(0, 1).contiguous() for x in [query_layer, key_layer, value_layer]
]
batch_size, max_seqlen_q, max_seqlen_kv = ( batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[1], query_layer.shape[1],
query_layer.shape[0], query_layer.shape[0],
...@@ -428,6 +457,15 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -428,6 +457,15 @@ class UnfusedDotProductAttention(torch.nn.Module):
matmul_result, None, None, dP_quantizer, "dP_quantizer", None matmul_result, None, None, dP_quantizer, "dP_quantizer", None
) )
# max attention score
max_logit = None
if self.return_max_logit:
# matmul_result [b, np, sq, dk], max_logit [np]
max_logit = matmul_result
if attn_mask_type != "no_mask":
max_logit = self.mask_func(matmul_result, attention_mask)
max_logit = torch.amax(max_logit, dim=(0, 2, 3))
# add attention sink to the last column: [b, np, sq, sk+1] # add attention sink to the last column: [b, np, sq, sk+1]
if self.softmax_type != "vanilla": if self.softmax_type != "vanilla":
matmul_result = torch.cat( matmul_result = torch.cat(
...@@ -508,14 +546,13 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -508,14 +546,13 @@ class UnfusedDotProductAttention(torch.nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# [b, sq, np, hn] --> [tq, np, hn] # [b, sq, np, hn] --> [tq, np, hn]
context_layer = tex.convert_bshd_to_thd( context_layer = ConvertBSHDtoTHD.apply(
context_layer, context_layer,
cu_seqlens_q, cu_seqlens_q,
total_tokens,
) )
# [tq, np, hn] --> [tq, hp] # [tq, np, hn] --> [tq, hp]
context_layer = context_layer.view(total_tokens, -1) context_layer = context_layer.view(context_layer.shape[0], -1)
if fp8: if fp8:
# quantize and dequantize O to emulate FP8 # quantize and dequantize O to emulate FP8
...@@ -531,6 +568,9 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -531,6 +568,9 @@ class UnfusedDotProductAttention(torch.nn.Module):
if fp8_output: if fp8_output:
context_layer = O_quantizer(context_layer) context_layer = O_quantizer(context_layer)
if self.return_max_logit:
return context_layer, max_logit
return context_layer return context_layer
...@@ -1069,6 +1109,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1069,6 +1109,7 @@ class FusedAttnFunc(torch.autograd.Function):
softmax_offset, softmax_offset,
fp8_output, fp8_output,
layer_number, layer_number,
return_max_logit,
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -1104,6 +1145,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1104,6 +1145,7 @@ class FusedAttnFunc(torch.autograd.Function):
# FP8 attention: torch.float16 or torch.bfloat16 # FP8 attention: torch.float16 or torch.bfloat16
out_nominal_dtype = q.dtype out_nominal_dtype = q.dtype
max_logit = None
if fp8: if fp8:
fused_attention_backend = FusedAttnBackend["FP8"] fused_attention_backend = FusedAttnBackend["FP8"]
...@@ -1131,7 +1173,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1131,7 +1173,7 @@ class FusedAttnFunc(torch.autograd.Function):
# DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16
# fp8_dtype = tex.DType.kFloat8E4M3 # fp8_dtype = tex.DType.kFloat8E4M3
# Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd( out_, aux_ctx_tensors, *_ = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
...@@ -1207,7 +1249,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1207,7 +1249,7 @@ class FusedAttnFunc(torch.autograd.Function):
qkvo_tensors = (q, k, v, out) qkvo_tensors = (q, k, v, out)
else: else:
# q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16
out_, aux_ctx_tensors = fused_attn_fwd( out_, aux_ctx_tensors, *max_logit = fused_attn_fwd(
is_training, is_training,
max_seqlen_q, max_seqlen_q,
max_seqlen_kv, max_seqlen_kv,
...@@ -1235,6 +1277,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1235,6 +1277,7 @@ class FusedAttnFunc(torch.autograd.Function):
window_size, window_size,
rng_gen, rng_gen,
softmax_offset, softmax_offset,
return_max_logit,
) )
out = out_ out = out_
out_ret = out_ out_ret = out_
...@@ -1329,10 +1372,12 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1329,10 +1372,12 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.use_FAv2_bwd = use_FAv2_bwd ctx.use_FAv2_bwd = use_FAv2_bwd
ctx.deterministic = deterministic ctx.deterministic = deterministic
if return_max_logit:
return out_ret, *max_logit
return out_ret return out_ret
@staticmethod @staticmethod
def backward(ctx, d_out): def backward(ctx, d_out, *_args):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
# d_out is expected to be in FP8 if is_output_fp8=True, # d_out is expected to be in FP8 if is_output_fp8=True,
...@@ -1576,6 +1621,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1576,6 +1621,7 @@ class FusedAttnFunc(torch.autograd.Function):
d_softmax_offset, d_softmax_offset,
None, None,
None, None,
None,
) )
...@@ -1616,6 +1662,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1616,6 +1662,7 @@ class FusedAttention(torch.nn.Module):
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
deterministic: bool = False, deterministic: bool = False,
softmax_type: str = "vanilla", softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1629,6 +1676,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1629,6 +1676,7 @@ class FusedAttention(torch.nn.Module):
self.layer_number = 1 if layer_number is None else layer_number self.layer_number = 1 if layer_number is None else layer_number
self.deterministic = deterministic self.deterministic = deterministic
self.softmax_type = softmax_type self.softmax_type = softmax_type
self.return_max_logit = return_max_logit
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
""" """
...@@ -1848,6 +1896,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1848,6 +1896,7 @@ class FusedAttention(torch.nn.Module):
softmax_offset=softmax_offset, softmax_offset=softmax_offset,
fp8_output=fp8_output, fp8_output=fp8_output,
layer_number=self.layer_number, layer_number=self.layer_number,
return_max_logit=self.return_max_logit,
) )
else: else:
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
...@@ -1883,7 +1932,11 @@ class FusedAttention(torch.nn.Module): ...@@ -1883,7 +1932,11 @@ class FusedAttention(torch.nn.Module):
softmax_offset, softmax_offset,
fp8_output, fp8_output,
self.layer_number, self.layer_number,
self.return_max_logit,
) )
if self.return_max_logit:
# ...hd -> ...(hd)
return output[0].view(*output[0].shape[:-2], -1), output[1]
# ...hd -> ...(hd) # ...hd -> ...(hd)
return output.view(*output.shape[:-2], -1) return output.view(*output.shape[:-2], -1)
...@@ -255,6 +255,12 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -255,6 +255,12 @@ class DotProductAttention(TransformerEngineBaseModule):
where alpha is a learnable parameter in shape [h]. where alpha is a learnable parameter in shape [h].
'off-by-one' and 'learnable' softmax types are also called sink attention 'off-by-one' and 'learnable' softmax types are also called sink attention
('zero sink' and 'learnable sink'). ('zero sink' and 'learnable sink').
return_max_logit: Optional[bool], default = `False`
If true, returns the maximum attention score that can be used in a Muon optimizer to
rescale the Q and K projection weights (see `Muon is Scalable for LLM Training
<https://arxiv.org/pdf/2502.16982>`_).
max_logit = max(S), where S = mask(Q*K^T*softmax_scale + bias) in shape [b, h, s_q, s_kv],
and max_logit is in shape [h].
Parallelism parameters Parallelism parameters
---------------------- ----------------------
...@@ -311,6 +317,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -311,6 +317,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_comm_type: str = "p2p", cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla", softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -394,6 +401,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -394,6 +401,7 @@ class DotProductAttention(TransformerEngineBaseModule):
self.attention_type = attention_type self.attention_type = attention_type
self.attention_dropout = attention_dropout self.attention_dropout = attention_dropout
self.return_max_logit = return_max_logit
self.softmax_type = softmax_type self.softmax_type = softmax_type
if self.softmax_type == "vanilla": if self.softmax_type == "vanilla":
...@@ -431,6 +439,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -431,6 +439,7 @@ class DotProductAttention(TransformerEngineBaseModule):
deterministic=self.deterministic, deterministic=self.deterministic,
**attn_kwargs, **attn_kwargs,
softmax_type=self.softmax_type, softmax_type=self.softmax_type,
return_max_logit=self.return_max_logit,
) )
self.unfused_attention = UnfusedDotProductAttention( self.unfused_attention = UnfusedDotProductAttention(
...@@ -439,6 +448,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -439,6 +448,7 @@ class DotProductAttention(TransformerEngineBaseModule):
**attn_kwargs, **attn_kwargs,
layer_number=layer_number, layer_number=layer_number,
softmax_type=self.softmax_type, softmax_type=self.softmax_type,
return_max_logit=self.return_max_logit,
) )
def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
...@@ -1303,6 +1313,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1303,6 +1313,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_meta=self.fp8_meta, fp8_meta=self.fp8_meta,
inference_params=inference_params, inference_params=inference_params,
softmax_type=self.softmax_type, softmax_type=self.softmax_type,
return_max_logit=self.return_max_logit,
) )
global _attention_backends global _attention_backends
if is_in_onnx_export_mode(): if is_in_onnx_export_mode():
...@@ -1502,6 +1513,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1502,6 +1513,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attention_mask=attention_mask, attention_mask=attention_mask,
window_size=window_size, window_size=window_size,
...@@ -1523,6 +1536,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1523,6 +1536,8 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attention_mask=attention_mask, attention_mask=attention_mask,
window_size=window_size, window_size=window_size,
......
...@@ -139,6 +139,7 @@ def fused_attn_fwd( ...@@ -139,6 +139,7 @@ def fused_attn_fwd(
window_size: Tuple[int, int] = (-1, -1), window_size: Tuple[int, int] = (-1, -1),
rng_gen: torch.Generator = None, rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None, softmax_offset: torch.Tensor = None,
return_max_logit: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input. """Fused Attention FWD for separate QKV input.
...@@ -216,6 +217,8 @@ def fused_attn_fwd( ...@@ -216,6 +217,8 @@ def fused_attn_fwd(
softmax_offset: torch.Tensor, default = None softmax_offset: torch.Tensor, default = None
softmax offset tensor in shape [1, h_q, 1, 1]. softmax offset tensor in shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details. See softmax_type in DotProductAttention for details.
return_max_logit: bool, default = False
whether to return the maximum attention score
Returns Returns
---------- ----------
...@@ -246,6 +249,7 @@ def fused_attn_fwd( ...@@ -246,6 +249,7 @@ def fused_attn_fwd(
rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen
state of the random number generator; state of the random number generator;
[seed, offset], dtype uint64 [seed, offset], dtype uint64
max_logit: if return_max_logit = True, shape [h] and same data type as O; otherwise None
""" """
if attn_scale is None: if attn_scale is None:
...@@ -315,8 +319,22 @@ def fused_attn_fwd( ...@@ -315,8 +319,22 @@ def fused_attn_fwd(
softmax_offset, softmax_offset,
rng_gen, rng_gen,
rng_elts_per_thread, rng_elts_per_thread,
return_max_logit,
) )
if return_max_logit:
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
# thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
stats = output_tensors[1] + torch.log(output_tensors[2])
amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
# Max -> max_logit [h]
max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
aux_ctx_tensors = [stats]
aux_ctx_tensors.extend(output_tensors[3:])
return output_tensors[0], aux_ctx_tensors, max_logit
# out, aux_ctx_tensors # out, aux_ctx_tensors
return output_tensors[0], output_tensors[1:] return output_tensors[0], output_tensors[1:]
......
...@@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( ...@@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right); int64_t window_size_right, bool return_max_logit);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer, std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype, const std::vector<size_t> &shape, DType dtype,
...@@ -94,7 +94,7 @@ std::vector<py::object> fused_attn_fwd( ...@@ -94,7 +94,7 @@ std::vector<py::object> fused_attn_fwd(
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v, const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias, py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen, const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread); size_t rng_elts_per_thread, bool return_max_logit);
std::vector<py::object> fused_attn_bwd( std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment