Commit 9df0c4a3 authored by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d
......@@ -44,6 +44,7 @@ __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]
ActivationEnum = {
("gelu",): NVTE_Activation_Type.GELU,
("gelu", "linear"): NVTE_Activation_Type.GEGLU,
("sigmoid", "linear"): NVTE_Activation_Type.GLU,
("silu",): NVTE_Activation_Type.SILU,
("silu", "linear"): NVTE_Activation_Type.SWIGLU,
("relu",): NVTE_Activation_Type.RELU,
......
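The new GLU entry follows the same gated-pair convention as the existing GEGLU and SWIGLU mappings: one activation applied to half of the last dimension, gating the other, linear half. A minimal JAX sketch of that convention (the half ordering is an assumption for illustration; the extension actually dispatches to the nvte_glu kernel):

import jax.numpy as jnp
from jax.nn import sigmoid

def glu_reference(x: jnp.ndarray) -> jnp.ndarray:
    # Split the gated pair along the last dimension: gate half, linear half.
    gate, linear = jnp.split(x, 2, axis=-1)
    return sigmoid(gate) * linear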
......@@ -70,6 +70,7 @@ __all__ = [
"is_training",
"max_segments_per_seq",
"window_size",
"bottom_right_diagonal",
"context_parallel_load_balanced",
"cp_axis",
"cp_striped_window_size",
......@@ -91,6 +92,7 @@ class _FusedAttnConfig:
is_training: bool
max_segments_per_seq: int
window_size: Tuple[int, int]
bottom_right_diagonal: bool
context_parallel_load_balanced: bool
cp_axis: str
cp_striped_window_size: Tuple[int, int] # Only for CP + Ring P2P + THD + SWA
......@@ -144,6 +146,7 @@ class FusedAttnHelper:
self.head_dim_v,
self.window_size[0],
self.window_size[1],
not self.is_non_deterministic_allowed(),
)
@staticmethod
......@@ -370,6 +373,11 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
bottom_right_diagonal = config.attn_mask_type in [
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
]
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
input_batch = reduce(operator.mul, batch_shape)
......@@ -394,6 +402,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
bottom_right_diagonal,
)
wkspace_aval = q_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -502,6 +511,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left,
window_size_right=window_size_right,
bottom_right_diagonal=config.bottom_right_diagonal,
softmax_type=int(config.softmax_type.value),
)
......@@ -812,6 +822,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
config.bottom_right_diagonal,
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
......@@ -947,6 +958,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left,
window_size_right=window_size_right,
bottom_right_diagonal=config.bottom_right_diagonal,
softmax_type=int(config.softmax_type.value),
)
......@@ -1356,9 +1368,10 @@ class _FusedAttnCPWithAllGatherHelper:
def get_step_config(self) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
adjusted_mask = self.get_adjusted_mask()
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(),
attn_mask_type=adjusted_mask,
softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor,
......@@ -1366,6 +1379,7 @@ class _FusedAttnCPWithAllGatherHelper:
is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size,
bottom_right_diagonal=adjusted_mask.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
......@@ -1374,9 +1388,10 @@ class _FusedAttnCPWithAllGatherHelper:
def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention."""
adjusted_mask = self.get_adjusted_mask()
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(),
attn_mask_type=adjusted_mask,
softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor,
......@@ -1384,6 +1399,7 @@ class _FusedAttnCPWithAllGatherHelper:
is_training=self.config.is_training,
max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size),
window_size=self.config.window_size,
bottom_right_diagonal=adjusted_mask.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
......@@ -2429,6 +2445,7 @@ class _FusedAttnCPWithP2PHelper:
is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
......@@ -3417,6 +3434,7 @@ def fused_attn_fwd(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
......@@ -3563,13 +3581,21 @@ def fused_attn_bwd(
softmax_offset, (None, HEAD_AXES, None, None)
)
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+
compute_capabilities = get_all_device_compute_capability()
if any(x >= 100 for x in compute_capabilities):
assert not (
attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
if any(x >= 100 for x in compute_capabilities) and is_training:
assert (
FusedAttnHelper.is_non_deterministic_allowed()
and get_cudnn_version() >= (9, 7, 0)
and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0)
) or (
not FusedAttnHelper.is_non_deterministic_allowed()
and get_cudnn_version() >= (9, 18, 1)
and attn_bias_type == AttnBiasType.NO_BIAS
and dropout_probability == 0.0
), (
"For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout,"
" and deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout"
)
fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type,
......@@ -3581,6 +3607,7 @@ def fused_attn_bwd(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
......
......@@ -113,7 +113,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
int64_t window_size_right);
int64_t window_size_right, bool deterministic);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
......@@ -121,7 +121,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
int64_t window_size_right, bool bottom_right_diagonal);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
......@@ -129,7 +129,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
int64_t window_size_left, int64_t window_size_right);
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal);
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
......@@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);
// Inspect
XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler);
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
......
......@@ -109,6 +109,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
case NVTE_Activation_Type::GEGLU:
nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::GLU:
nvte_glu(input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SILU:
nvte_silu(input_tensor.data(), output_tensor.data(), stream);
break;
......@@ -427,6 +430,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
case NVTE_Activation_Type::GEGLU:
nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::GLU:
nvte_dglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
case NVTE_Activation_Type::SWIGLU:
nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
break;
......
......@@ -5,8 +5,6 @@
************************************************************************/
#include <cuda_runtime.h>
#include <iostream>
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
......
......@@ -16,12 +16,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
int64_t window_size_right) {
int64_t window_size_right, bool deterministic) {
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false, false);
false, false, deterministic);
return backend;
}
......@@ -144,7 +144,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) {
int64_t window_size_right, bool bottom_right_diagonal) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
......@@ -192,7 +192,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
window_size_left, window_size_right, bottom_right_diagonal, query_workspace_tensor.data(),
nullptr);
}
nvte_tensor_pack_destroy(&aux_output_tensors);
......@@ -237,7 +238,7 @@ static void FusedAttnForwardImpl(
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
......@@ -266,7 +267,7 @@ static void FusedAttnForwardImpl(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false, false);
false, false, deterministic);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -328,7 +329,7 @@ static void FusedAttnForwardImpl(
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
......@@ -346,6 +347,7 @@ static void FusedAttnForwardImpl(
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
bool bottom_right_diagonal = get_attr_value<bool>(attrs, "bottom_right_diagonal"); \
float scaling_factor = get_attr_value<double>(attrs, "scaling_factor"); \
float dropout_probability = get_attr_value<double>(attrs, "dropout_probability"); \
NVTE_Bias_Type bias_type = \
......@@ -384,7 +386,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype,
is_training, deterministic, window_size_left, window_size_right);
is_training, deterministic, window_size_left, window_size_right, bottom_right_diagonal);
return ffi_with_cuda_error_check();
}
......@@ -415,7 +417,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
int64_t window_size_left, int64_t window_size_right) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
......@@ -467,17 +469,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, bottom_right_diagonal, deterministic, false,
query_workspace_tensor.data(), nullptr);
}
nvte_tensor_pack_destroy(&aux_input_tensors);
......@@ -496,7 +499,7 @@ static void FusedAttnBackwardImpl(
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
......@@ -522,7 +525,7 @@ static void FusedAttnBackwardImpl(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
false, false);
false, false, deterministic);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias, softmax_offset);
......@@ -593,16 +596,17 @@ static void FusedAttnBackwardImpl(
}
}
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dsoftmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
......@@ -631,7 +635,7 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T
q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim,
max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type,
softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
window_size_right);
window_size_right, bottom_right_diagonal);
return ffi_with_cuda_error_check();
}
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <fstream>
#include <iostream>
#include "../extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
Result_Type output_buf) {
NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
NVTE_CHECK(output_buf->untyped_data() != nullptr,
"Output must be provided for inspect operation");
NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
"Input and output must point to the same buffer for inspect operation");
std::vector<uint8_t> input_data(input_buf.size_bytes());
NVTE_CHECK_CUDA(cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(),
input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream));
float min_val{}, max_val{}, mean_val{}, std_val{};
NVTE_CHECK_CUDA(cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float),
cudaMemcpyDeviceToHost, stream));
NVTE_CHECK_CUDA(cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float),
cudaMemcpyDeviceToHost, stream));
NVTE_CHECK_CUDA(cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float),
cudaMemcpyDeviceToHost, stream));
NVTE_CHECK_CUDA(cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float),
cudaMemcpyDeviceToHost, stream));
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
int device;
NVTE_CHECK_CUDA(cudaGetDevice(&device));
// Write the tensor data to a file as a binary blob
std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
std::ofstream file(filename, std::ios::binary);
NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
file.close();
// Write out a metadata file
std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
std::ofstream meta_file(meta_filename);
NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename);
meta_file << "{";
meta_file << "\"shape\": [";
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
meta_file << input_buf.dimensions()[i];
if (i < input_buf.dimensions().size() - 1) {
meta_file << ", ";
}
}
meta_file << "], ";
meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
meta_file << ", \"min\": " << min_val;
meta_file << ", \"max\": " << max_val;
meta_file << ", \"mean\": " << mean_val;
meta_file << ", \"std\": " << std_val;
meta_file << "}";
meta_file.close();
// Log the tensor metadata to the console
printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
if (i < input_buf.dimensions().size() - 1) {
printf(", ");
}
}
printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // min
.Arg<Buffer_Type>() // max
.Arg<Buffer_Type>() // mean
.Arg<Buffer_Type>() // std
.Ret<Buffer_Type>() // output
);
} // namespace jax
} // namespace transformer_engine
......@@ -81,6 +81,9 @@ pybind11::dict Registrations() {
pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));
dict["te_inspect_ffi"] =
pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));
return dict;
}
......@@ -150,6 +153,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
.value("GELU", NVTE_Activation_Type::GELU)
.value("GEGLU", NVTE_Activation_Type::GEGLU)
.value("GLU", NVTE_Activation_Type::GLU)
.value("SILU", NVTE_Activation_Type::SILU)
.value("SWIGLU", NVTE_Activation_Type::SWIGLU)
.value("RELU", NVTE_Activation_Type::RELU)
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.
This API is experimental and may change or be removed without deprecation in future releases.
"""
__all__ = [
"experimental",
]
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.
This API is experimental and may change or be removed without deprecation in future releases.
"""
from .inspect import inspect_array, load_array_dump
__all__ = [
"inspect_array",
"load_array_dump",
]
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Experimental JAX array inspection utilities."""
from functools import partial
import jax
import jax.numpy as jnp
from jax import ffi
from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive
__all__ = ["inspect_array", "load_array_dump"]
class InspectPrimitive(BasePrimitive):
"""
No-op primitive used to inspect array values.
"""
name = "te_inspect_ffi"
multiple_results = False
impl_static_args = ()
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
x_min_aval,
x_max_aval,
x_mean_aval,
x_std_aval,
):
"""
Abstract evaluation: check that the stats are float32 scalars and return the input aval unchanged.
"""
assert (
x_min_aval.shape == () and x_min_aval.dtype == jnp.float32
), "x_min must be a scalar with dtype float32"
assert (
x_max_aval.shape == () and x_max_aval.dtype == jnp.float32
), "x_max must be a scalar with dtype float32"
assert (
x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32
), "x_mean must be a scalar with dtype float32"
assert (
x_std_aval.shape == () and x_std_aval.dtype == jnp.float32
), "x_std must be a scalar with dtype float32"
return x_aval
@staticmethod
def lowering(
ctx,
x,
x_min,
x_max,
x_mean,
x_std,
):
"""
Lowering rule: emit the FFI call, aliasing the input buffer to the output.
"""
return ffi.ffi_lowering(
InspectPrimitive.name,
operand_output_aliases={0: 0}, # donate input buffer to output buffer
)(
ctx,
x,
x_min,
x_max,
x_mean,
x_std,
)
@staticmethod
def impl(
x,
x_min,
x_max,
x_mean,
x_std,
):
"""
Implementation: bind the inner primitive and return the aliased output.
"""
assert InspectPrimitive.inner_primitive is not None
x = InspectPrimitive.inner_primitive.bind(
x,
x_min,
x_max,
x_mean,
x_std,
)
return x
register_primitive(InspectPrimitive)
def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
assert InspectPrimitive.outer_primitive is not None, (
"InspectPrimitive FFI is not registered. Please ensure the C++ extension is properly built"
" and registered."
)
return InspectPrimitive.outer_primitive.bind(
x,
jnp.min(x).astype(jnp.float32),
jnp.max(x).astype(jnp.float32),
jnp.mean(x.astype(jnp.float32)),
jnp.std(x.astype(jnp.float32)),
)
@partial(jax.custom_vjp, nondiff_argnums=())
def _inspect(
x,
):
""" """
output, _ = _inspect_fwd_rule(
x,
)
return output
def _inspect_fwd_rule(
x,
):
""""""
ctx = ()
x = _inspect_array_inner(x)
return x, ctx
def _inspect_bwd_rule(
ctx,
grad,
):
""""""
del ctx
return (grad,)
_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)
def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
"""Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.
Args:
x (jnp.ndarray): The JAX array to inspect.
name (str): The name of the array for identification in the output.
"""
del name # Name is currently unused, but can be included in the future for more informative output
return _inspect(x)
def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:
"""Utility function to load a JAX array from a dumped binary file.
Args:
filename (str): The path to the binary file containing the array data.
shape (tuple): The shape of the array to be loaded.
dtype (jnp.dtype): The data type of the array to be loaded.
Returns:
jnp.ndarray: The loaded JAX array.
"""
with open(filename, "rb") as f:
data = f.read()
array = jnp.frombuffer(data, dtype=dtype).reshape(shape)
return array
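A usage sketch for the new utilities. The import path and the reloaded file name are assumptions based on this commit's package layout and the FFI handler's "my_tensor_gpu<device>.bin" naming; the shape and dtype of the dump must be known to the caller:

import jax
import jax.numpy as jnp
# Assumed import path for the experimental debug utilities added here.
from transformer_engine.jax.experimental.debug import inspect_array, load_array_dump

@jax.jit
def forward(x):
    h = jnp.tanh(x)
    # Identity in the computation; dumps the array and its stats as a side effect.
    h = inspect_array(h, name="post_tanh")
    return h * 2.0

y = forward(jnp.ones((4, 8), dtype=jnp.float32))

# Reload the dump written by GPU 0 for offline analysis.
restored = load_array_dump("my_tensor_gpu0.bin", shape=(4, 8), dtype=jnp.float32)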
......@@ -52,7 +52,7 @@ def token_dispatch(
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
jnp.ndarray,
]:
"""
Dispatch tokens to experts based on routing map.
......@@ -101,9 +101,11 @@ def token_dispatch(
pad_offsets : Optional[jnp.ndarray]
Per-expert cumulative padding offsets of shape [num_experts] when using padding,
None otherwise. Pass this to token_combine when unpadding is needed.
target_tokens_per_expert : Optional[jnp.ndarray]
Aligned token counts per expert of shape [num_experts] when using padding,
None otherwise.
tokens_per_expert : jnp.ndarray
Token counts per expert of shape [num_experts]:
- Without padding: actual token counts (sum of routing_map columns)
- With padding: aligned token counts (ceil(actual / align_size) * align_size)
This gives the effective number of tokens per expert in the output buffer.
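For example, with align_size = 128 and 300 tokens routed to an expert, the padded output reserves ceil(300 / 128) * 128 = 384 slots for that expert.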
Note
----
......@@ -151,10 +153,10 @@ def _token_dispatch(
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
jnp.ndarray,
]:
"""Internal token_dispatch with custom VJP."""
(output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
(output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert), _ = (
_token_dispatch_fwd_rule(
inp,
routing_map,
......@@ -165,7 +167,7 @@ def _token_dispatch(
use_padding,
)
)
return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
return output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert
def _token_dispatch_fwd_rule(
......@@ -182,7 +184,7 @@ def _token_dispatch_fwd_rule(
Optional[jnp.ndarray],
jnp.ndarray,
Optional[jnp.ndarray],
Optional[jnp.ndarray],
jnp.ndarray,
],
Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
]:
......@@ -212,11 +214,11 @@ def _token_dispatch_fwd_rule(
with_probs = probs is not None
if use_padding:
# Compute tokens_per_expert internally from routing_map
# This can be a traced value since output shape uses worst_case_out_tokens
# Compute tokens_per_expert from routing_map (actual counts)
# This is well-optimized by XLA as a simple column-wise reduction
tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
if use_padding:
# Calculate aligned token counts per expert
target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
jnp.int32
......@@ -242,10 +244,12 @@ def _token_dispatch_fwd_rule(
hidden_size,
align_size=align_size,
)
# Return aligned counts when using padding
out_tokens_per_expert = target_tokens_per_expert
else:
# No padding
pad_offsets = None
target_tokens_per_expert = None
output, permuted_probs = permute_with_mask_map(
inp,
......@@ -257,14 +261,20 @@ def _token_dispatch_fwd_rule(
hidden_size,
)
# Return actual counts when not using padding
out_tokens_per_expert = tokens_per_expert
# Return (primals, residuals)
# out_tokens_per_expert is:
# - target_tokens_per_expert (aligned) when using padding
# - tokens_per_expert (actual) when not using padding
residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
return (
output,
permuted_probs,
row_id_map,
pad_offsets,
target_tokens_per_expert,
out_tokens_per_expert,
), residuals
......@@ -571,7 +581,7 @@ def sort_chunks_by_index(
return _sort_chunks_by_index(inp, split_sizes, sorted_indices)
@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
@jax.custom_vjp
def _sort_chunks_by_index(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
......@@ -586,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]]:
"""Forward pass rule for sort_chunks_by_index."""
# Validate input dimensions
assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
......@@ -608,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule(
)
# Return (primals, residuals)
residuals = (row_id_map, num_tokens, hidden_size)
# Include split_sizes and sorted_indices in residuals since we removed nondiff_argnums
residuals = (row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size)
return (output, row_id_map), residuals
def _sort_chunks_by_index_bwd_rule(
_split_sizes: jnp.ndarray,
_sorted_indices: jnp.ndarray,
residuals: Tuple[jnp.ndarray, int, int],
residuals: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int],
g: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray]:
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Backward pass rule for sort_chunks_by_index."""
row_id_map, num_tokens, hidden_size = residuals
row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size = residuals
output_grad, _ = g
# Backward: reverse the sort
......@@ -632,7 +641,12 @@ def _sort_chunks_by_index_bwd_rule(
is_forward=False,
)
return (inp_grad,)
# Return gradients for all inputs: (inp, split_sizes, sorted_indices)
# split_sizes and sorted_indices are integer arrays, so their gradients are zeros
split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype)
sorted_indices_grad = jnp.zeros_like(sorted_indices, dtype=sorted_indices.dtype)
return (inp_grad, split_sizes_grad, sorted_indices_grad)
_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)
......@@ -65,8 +65,6 @@ class RowIdMapPass1Primitive(BasePrimitive):
@staticmethod
def abstract(routing_map_aval, *, num_tokens, num_experts, block_size):
"""Shape/dtype inference for pass 1."""
del block_size # Only affects grid, not output shape
assert routing_map_aval.shape == (
num_tokens,
num_experts,
......@@ -75,7 +73,7 @@ class RowIdMapPass1Primitive(BasePrimitive):
row_id_map_shape = (num_tokens, num_experts * 2 + 1)
workspace_shape = (
num_experts,
triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE),
triton.cdiv(num_tokens, block_size),
)
return (
......@@ -134,9 +132,10 @@ class RowIdMapPass1Primitive(BasePrimitive):
desc="RowIdMapPass1.row_id_map_sharding",
)
# Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
# Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
workspace_sharding = NamedSharding(
mesh,
PartitionSpec(None, None),
PartitionSpec(None, routing_map_spec[0]),
desc="RowIdMapPass1.workspace_sharding",
)
return [row_id_map_sharding, workspace_sharding]
......@@ -156,9 +155,11 @@ class RowIdMapPass1Primitive(BasePrimitive):
PartitionSpec(routing_map_spec[0], None),
desc="RowIdMapPass1.row_id_map_sharding",
)
# Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
# Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
workspace_sharding = NamedSharding(
mesh,
PartitionSpec(None, None),
PartitionSpec(None, routing_map_spec[0]),
desc="RowIdMapPass1.workspace_sharding",
)
out_shardings = [row_id_map_sharding, workspace_sharding]
......@@ -186,7 +187,8 @@ class RowIdMapPass1Primitive(BasePrimitive):
# Note: row_id_cols != experts since it's num_experts * 2 + 1
row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
# workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
workspace_spec = (f"{prefix}_experts", f"{prefix}_ws_blocks")
# Second dim depends on num_tokens, so use same factor to ensure same sharding
workspace_spec = (f"{prefix}_experts", f"{prefix}_tokens")
return SdyShardingRule((input_spec,), (row_id_map_spec, workspace_spec))
......@@ -208,10 +210,9 @@ class RowIdMapPass2Primitive(BasePrimitive):
def abstract(row_id_map_aval, workspace_aval, *, num_tokens, num_experts, block_size):
"""Shape/dtype inference for pass 2 (in-place operation)."""
del row_id_map_aval, workspace_aval
del block_size
row_id_map_shape = (num_tokens, num_experts * 2 + 1)
workspace_shape = (num_experts, triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE))
workspace_shape = (num_experts, triton.cdiv(num_tokens, block_size))
return (
jax.core.ShapedArray(row_id_map_shape, jnp.int32),
......@@ -270,9 +271,11 @@ class RowIdMapPass2Primitive(BasePrimitive):
PartitionSpec(*row_id_map_spec),
desc="RowIdMapPass2.row_id_map_sharding",
)
# Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
# Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
workspace_sharding = NamedSharding(
mesh,
PartitionSpec(None, None),
PartitionSpec(None, row_id_map_spec[0]),
desc="RowIdMapPass2.workspace_sharding",
)
return [row_id_map_sharding, workspace_sharding]
......@@ -292,9 +295,11 @@ class RowIdMapPass2Primitive(BasePrimitive):
PartitionSpec(*row_id_map_spec),
desc="RowIdMapPass2.row_id_map_sharding",
)
# Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
# Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
workspace_sharding = NamedSharding(
mesh,
PartitionSpec(None, None),
PartitionSpec(None, row_id_map_spec[0]),
desc="RowIdMapPass2.workspace_sharding",
)
out_shardings = [row_id_map_sharding, workspace_sharding]
......@@ -317,7 +322,9 @@ class RowIdMapPass2Primitive(BasePrimitive):
del num_tokens, num_experts, block_size, mesh, value_types, result_types
prefix = "RowIdMapPass2"
row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols")
workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks")
# workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
# Second dim depends on num_tokens, so use same factor to ensure same sharding
workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_tokens")
return SdyShardingRule((row_id_map_spec, workspace_spec), (row_id_map_spec, workspace_spec))
......
......@@ -36,6 +36,8 @@ import warnings
from typing import Any, Callable, Mapping
import zlib
from packaging import version
from jax import core
import jax
import jax.numpy as jnp
......@@ -274,13 +276,16 @@ def compile_triton(
return _TRITON_KERNEL_CACHE[cache_key]
# Compile kernel
cuda_option_kwargs = {}
if version.parse(_TRITON_VERSION) < version.parse("3.6.0"):
cuda_option_kwargs["cluster_dims"] = (1, 1, 1)
options = cb.CUDAOptions(
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
cluster_dims=(1, 1, 1),
debug=False,
enable_fp_fusion=enable_fp_fusion,
**cuda_option_kwargs,
)
# Mark constants as constexpr in signature
......@@ -303,8 +308,6 @@ def compile_triton(
# Create kernel object for JAX
# From jax/jaxlib/gpu/triton_kernels.cc:
from packaging import version
if version.parse(jax.__version__) >= version.parse("0.8.2"):
kernel = gpu_triton.TritonKernel(
compiled.name, # arg0: kernel_name (str)
......
......@@ -166,6 +166,11 @@ class FP8EmulationFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
# pylint: disable=missing-function-docstring
if is_in_onnx_export_mode():
return FP8EmulationFunc.onnx_forward(
tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout
)
if quantizer_name == "QKV_quantizer":
query_layer, key_layer, value_layer = [
x.contiguous() for x in [tensor1, tensor2, tensor3]
......@@ -204,6 +209,47 @@ class FP8EmulationFunc(torch.autograd.Function):
tensors = grad1, grad2, grad3
return tensors[0], tensors[1], tensors[2], None, None, None
@staticmethod
def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None):
"""
ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations.
"""
# pylint: disable=unused-argument
is_qkv_quantizer = quantizer_name == "QKV_quantizer"
assert isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
), "ONNX FP8 emulation path supports only Float8 quantizers."
if is_qkv_quantizer:
# Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3.
orig_dtype = tensor1.dtype
shapes = [tensor1.shape, tensor2.shape, tensor3.shape]
numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()]
# Flatten and concatenate
combined = torch.cat(
[tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
)
# Quantize + dequantize combined tensor using quantizer's ONNX methods
combined_fp8 = quantizer.onnx_quantize(combined)
out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype)
# Split back
out1 = out[: numels[0]].reshape(shapes[0])
out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
out3 = out[numels[0] + numels[1] :].reshape(shapes[2])
return out1, out2, out3
if quantizer_name in ["S_quantizer", "O_quantizer"]:
# Emulate FP8 on single tensor using quantizer's ONNX methods
orig_dtype = tensor1.dtype
t_fp8 = quantizer.onnx_quantize(tensor1)
out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype)
return out, tensor2, tensor3
# Pass-through
return tensor1, tensor2, tensor3
class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
......@@ -263,6 +309,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
......@@ -348,6 +395,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_mask=attention_mask,
window_size=window_size,
attention_type=self.attention_type,
bottom_right_alignment=(
attn_mask_type not in ["causal", "padding_causal"]
if bottom_right_diagonal is None
else bottom_right_diagonal
),
)
)
......@@ -451,7 +503,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
alibi_slopes=alibi_slopes,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
bottom_right_alignment=(
attn_mask_type not in ["causal", "padding_causal"]
if bottom_right_diagonal is None
else bottom_right_diagonal
),
)
matmul_result = torch.baddbmm(
matmul_result,
......@@ -1112,6 +1168,7 @@ class FusedAttnFunc(torch.autograd.Function):
attn_mask_type,
softmax_type,
window_size,
bottom_right_diagonal,
rng_gen,
fused_attention_backend,
use_FAv2_bwd,
......@@ -1215,6 +1272,7 @@ class FusedAttnFunc(torch.autograd.Function):
attn_mask_type,
softmax_type,
window_size,
bottom_right_diagonal,
rng_gen,
softmax_offset,
cuda_graph=is_graph_capturing(),
......@@ -1292,6 +1350,7 @@ class FusedAttnFunc(torch.autograd.Function):
attn_mask_type,
softmax_type,
window_size,
bottom_right_diagonal,
rng_gen,
softmax_offset,
return_max_logit,
......@@ -1379,6 +1438,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type
ctx.window_size = window_size
ctx.bottom_right_diagonal = bottom_right_diagonal
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
)
......@@ -1529,6 +1589,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.bottom_right_diagonal,
ctx.deterministic,
is_graph_capturing(),
)
......@@ -1594,6 +1655,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.bottom_right_diagonal,
ctx.deterministic,
is_graph_capturing(),
)
......@@ -1633,6 +1695,7 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
None,
d_softmax_offset,
None,
None,
......@@ -1730,6 +1793,7 @@ class FusedAttention(torch.nn.Module):
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
......@@ -1937,6 +2001,7 @@ class FusedAttention(torch.nn.Module):
attn_mask_type,
self.softmax_type,
window_size,
bottom_right_diagonal,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd,
......
......@@ -4026,28 +4026,30 @@ def attn_forward_func_with_cp(
assert not sliding_window_attn or cp_comm_type in [
"a2a",
"all_gather",
], "Context parallelism does not support sliding window attention with {cp_comm_type=}!"
], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!"
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
"p2p",
"a2a+p2p",
], "Context parallelism does not support MLA with {cp_comm_type=}!"
], f"Context parallelism does not support MLA with {cp_comm_type=}!"
if fp8 and fp8_meta is not None:
if fp8_meta["recipe"].fp8_dpa:
assert (
softmax_type == "vanilla"
), "Context parallelism does not support {softmax_type=} with FP8 attention!"
), f"Context parallelism does not support {softmax_type=} with FP8 attention!"
assert (
softmax_type == "vanilla" or use_fused_attention
), "Context parallelism only supports {softmax_type=} with FusedAttention backend!"
), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!"
assert (
softmax_type == "vanilla" or cp_comm_type == "a2a"
), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
assert (
softmax_type == "vanilla" or qkv_format != "thd"
), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
if get_cudnn_version() < (9, 18, 0):
assert softmax_type == "vanilla" or qkv_format != "thd", (
f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with"
" qkv_format = 'thd'!"
)
args = [
is_training,
......
......@@ -228,6 +228,11 @@ class DotProductAttention(TransformerEngineBaseModule):
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in ``forward`` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{'causal', 'padding_causal'} and `True` for other mask types.
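For example, with ``window_size = (2, 0)``, 2 queries and 6 keys, bottom right alignment lets query 0 attend to keys 2-4 and query 1 to keys 3-5, whereas top left alignment restricts them to key 0 and keys 0-1.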
attention_type : str, default = "self"
type of attention, either ``"self"`` or ``"cross"``.
layer_number : int, default = None
......@@ -324,6 +329,7 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_format: str = "sbhd",
attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
sequence_parallel: bool = False,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
......@@ -350,6 +356,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type = "padding_causal"
self.attn_mask_type = attn_mask_type
self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
self.bottom_right_diagonal = bottom_right_diagonal
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -676,9 +683,9 @@ class DotProductAttention(TransformerEngineBaseModule):
# assume attention uses the same fp8_group as GEMMs
fp8_group = FP8GlobalStateManager.get_fp8_group()
self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled())
self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration())
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
if self.fp8_parameters or fp8_enabled:
......@@ -703,7 +710,7 @@ class DotProductAttention(TransformerEngineBaseModule):
)
else:
# If fp8 isn't enabled, turn off and return.
self.fp8_initialized = False
self.fast_setattr("fp8_initialized", False)
return
if self.fp8_parameters and not self.fp8_initialized:
......@@ -721,7 +728,7 @@ class DotProductAttention(TransformerEngineBaseModule):
# Allocate scales and amaxes
self.init_fp8_meta_tensors(fp8_recipes)
self.fp8_initialized = True
self.fast_setattr("fp8_initialized", True)
self.fp8_meta["recipe"] = fp8_recipe_dpa
if fp8_recipe != fp8_recipe_dpa:
......@@ -811,6 +818,7 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv: int = None,
attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
......@@ -963,6 +971,16 @@ class DotProductAttention(TransformerEngineBaseModule):
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention.
bottom_right_diagonal: Optional[bool], default = None
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{'causal', 'padding_causal'} and `True` for other mask types.
Note: This parameter will be automatically overridden based on the
`attn_mask_type` - it will be forced to `False` for 'causal' and
'padding_causal' mask types, and forced to `True` for mask types
containing 'bottom_right' (e.g., 'causal_bottom_right',
'padding_causal_bottom_right'), regardless of the explicitly passed value.
checkpoint_core_attention : bool, default = False
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
......@@ -1000,7 +1018,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cases. It is ignored for other backends and when context parallelism is enabled.
"""
with self.prepare_forward(
with self.prepare_forward_ctx(
query_layer,
num_gemms=3,
allow_non_contiguous=True,
......@@ -1081,6 +1099,15 @@ class DotProductAttention(TransformerEngineBaseModule):
if window_size is None:
window_size = self.window_size
window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
# checks for qkv_format
if qkv_format is None:
......@@ -1144,8 +1171,11 @@ class DotProductAttention(TransformerEngineBaseModule):
assert "padding" in attn_mask_type, "KV caching requires padding mask!"
if attn_mask_type == "padding_causal":
attn_mask_type = attn_mask_type + "_bottom_right"
# since attention mask is changed, set `bottom_right_diagonal` to True
bottom_right_diagonal = True
self.attention_type = "cross"
if self.attention_type != "cross":
self.fast_setattr("attention_type", "cross")
self.flash_attention.attention_type = self.attention_type
self.fused_attention.attention_type = self.attention_type
self.unfused_attention.attention_type = self.attention_type
......@@ -1256,7 +1286,6 @@ class DotProductAttention(TransformerEngineBaseModule):
if self.layer_number == 1:
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
if core_attention_bias_type == "alibi":
assert (
core_attention_bias is None
......@@ -1265,7 +1294,7 @@ class DotProductAttention(TransformerEngineBaseModule):
_alibi_cache["_num_heads"] != query_layer.shape[-2]
or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal
or _alibi_cache["_alibi_slopes"] is None
):
_alibi_cache["_alibi_slopes_require_update"] = True
......@@ -1322,6 +1351,7 @@ class DotProductAttention(TransformerEngineBaseModule):
head_dim_v=head_dim_v,
attn_mask_type=attn_mask_type,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
......@@ -1445,9 +1475,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if use_fused_attention:
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if core_attention_bias_type == "alibi" and (
alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
):
if core_attention_bias_type == "alibi" and (alibi_slopes is not None):
fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = dpa_utils.get_alibi(
_alibi_cache,
......@@ -1456,7 +1484,7 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv,
alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
bottom_right_alignment=bottom_right_diagonal,
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
......@@ -1474,6 +1502,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
......@@ -1504,6 +1533,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
......@@ -1522,7 +1552,9 @@ class DotProductAttention(TransformerEngineBaseModule):
)
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
allow_emulation = (
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
......@@ -1538,6 +1570,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
......@@ -1561,6 +1594,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
......
......@@ -200,6 +200,9 @@ class AttentionParams:
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size : Tuple[int, int], default = None
Sliding window attention size.
bottom_right_diagonal: bool, default = `True`
Whether to align sliding window and ALiBi diagonal to the bottom right corner
of the softmax matrix.
alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type : str, default = no_bias
......@@ -249,6 +252,7 @@ class AttentionParams:
head_dim_v: int = 64
attn_mask_type: str = "no_mask"
window_size: Union[Tuple[int, int], None] = None
bottom_right_diagonal: bool = True
alibi_slopes_shape: Union[torch.Size, List, None] = None
core_attention_bias_type: str = "no_bias"
core_attention_bias_shape: str = "1hss"
......@@ -325,6 +329,7 @@ def get_attention_backend(
head_dim_v = attention_params.head_dim_v
attn_mask_type = attention_params.attn_mask_type
window_size = attention_params.window_size
bottom_right_diagonal = attention_params.bottom_right_diagonal
alibi_slopes_shape = attention_params.alibi_slopes_shape
core_attention_bias_type = attention_params.core_attention_bias_type
core_attention_bias_shape = attention_params.core_attention_bias_shape
......@@ -474,7 +479,9 @@ def get_attention_backend(
logger.debug("Disabling FlashAttention 3 for FP8 training")
use_flash_attention_3 = False
if use_unfused_attention:
allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
allow_emulation = (
os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
)
if not allow_emulation:
logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
use_unfused_attention = False
......@@ -730,22 +737,14 @@ def get_attention_backend(
)
use_unfused_attention = False
if qkv_format == "thd":
if cudnn_version < (9, 18, 0):
logger.debug(
"Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
)
use_fused_attention = False
logger.debug(
"Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
"Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN"
" version < 9.18",
softmax_type,
)
use_unfused_attention = False
use_fused_attention = False
if context_parallel:
logger.debug(
"Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
" = %s",
softmax_type,
)
use_unfused_attention = False
if cp_comm_type != "a2a":
logger.debug(
"Disabling FusedAttention for context parallelism with softmax_type = %s and"
......@@ -881,23 +880,21 @@ def get_attention_backend(
# backend | window_size | diagonal alignment
# ---------------------------------------------------------------------------------
# FlashAttention | (-1, -1) or (>=0, >=0) | bottom right
# FusedAttention | (-1, 0) or (>=0, 0) | top left
# UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
# FusedAttention | (-1, 0) or (>=0, >=0) | top left, bottom right
# UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | top left, bottom right
# | | converts window_size to an 'arbitrary' mask
if window_size is None:
window_size = check_set_window_size(attn_mask_type, window_size)
else:
if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention"
" for FP8"
"Disabling FusedAttention as it does not support sliding window attention for FP8"
)
use_fused_attention = False
elif window_size[1] != 0 or attention_dropout != 0.0:
elif attention_dropout != 0.0:
logger.debug(
"Disabling FusedAttention as it only supports sliding window attention "
"with (left, 0) and no dropout"
"without dropout"
)
use_fused_attention = False
elif max_seqlen_q > max_seqlen_kv:
......@@ -914,6 +911,12 @@ def get_attention_backend(
"Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
)
use_flash_attention_2 = False
elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FlashAttention as it only supports sliding window with bottom right"
" diagonal alignment for cross-attention"
)
use_flash_attention = False
# Filter: Attention bias
# backend | bias types | ALiBi diagonal alignment
......@@ -935,6 +938,12 @@ def get_attention_backend(
elif not FlashAttentionUtils.v2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention_2 = False
elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
" alignment for cross-attention"
)
use_flash_attention = False
if (
core_attention_bias_type not in ["no_bias", "alibi"]
......@@ -952,13 +961,12 @@ def get_attention_backend(
if (
use_fused_attention
and core_attention_bias_type == "alibi"
and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
and (alibi_slopes_shape is not None)
):
fu_core_attention_bias_type = "post_scale_bias"
fu_core_attention_bias_requires_grad = False
if alibi_slopes_shape is None:
fu_core_attention_bias_shape = "1hss"
elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
fu_core_attention_bias_shape = "1hss"
elif (
len(alibi_slopes_shape) == 2
......@@ -1008,6 +1016,7 @@ def get_attention_backend(
window_size[1],
return_max_logit,
cuda_graph,
deterministic,
)
if fused_attention_backend == FusedAttnBackend["No_Backend"]:
logger.debug("Disabling FusedAttention as no backend supports the provided input")
......@@ -1062,6 +1071,15 @@ def get_attention_backend(
)
use_flash_attention_2 = False
if use_fused_attention and deterministic:
if softmax_type != "vanilla":
logger.debug(
"Disabling FusedAttention for determinism reasons with softmax_type = %s. "
"Sink attention (off-by-one and learnable softmax) requires "
"NVTE_ALLOW_NONDETERMINISTIC_ALGO=1",
softmax_type,
)
use_fused_attention = False
fused_attention_backend = None
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons with FP8")
use_fused_attention = False
......@@ -1078,10 +1096,6 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
use_fused_attention = False
fused_attention_backend = None
if is_training and device_compute_capability >= (10, 0):
logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
use_fused_attention = False
fused_attention_backend = None
# use_flash_attention may have been set above
use_flash_attention_2 = use_flash_attention and use_flash_attention_2
......
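The new bottom_right_diagonal checks in get_attention_backend above reduce to a simple predicate: FlashAttention can only express sliding window and ALiBi with a bottom-right diagonal when the query and key sequence lengths differ, while equal lengths make the two alignments coincide. A minimal sketch under that reading; flash_attn_alignment_ok is a hypothetical helper name:

def flash_attn_alignment_ok(bottom_right_diagonal: bool,
                            max_seqlen_q: int,
                            max_seqlen_kv: int) -> bool:
    # Self-attention: the top-left and bottom-right diagonals are the same row band,
    # so either alignment works. Cross-attention with top-left alignment is the case
    # FlashAttention cannot express, and the backend is disabled above.
    return bottom_right_diagonal or max_seqlen_q == max_seqlen_kv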
......@@ -8,7 +8,6 @@ import collections
from typing import Callable, List, Optional, Tuple, Union
import torch
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
......@@ -32,6 +31,7 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils
from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled
......@@ -93,6 +93,11 @@ class MultiheadAttention(torch.nn.Module):
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in :meth:`forward` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
num_gqa_groups : int, default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
......@@ -248,6 +253,7 @@ class MultiheadAttention(torch.nn.Module):
layer_number: Optional[int] = None,
attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
num_gqa_groups: Optional[int] = None,
......@@ -286,6 +292,7 @@ class MultiheadAttention(torch.nn.Module):
self.qkv_format = qkv_format
self.attn_mask_type = attn_mask_type
self.window_size = window_size
self.bottom_right_diagonal = bottom_right_diagonal
self.layer_number = 1 if layer_number is None else layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
......@@ -335,6 +342,7 @@ class MultiheadAttention(torch.nn.Module):
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
self.name = name
TransformerEngineBaseModule._validate_name(self)
common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation,
......@@ -621,6 +629,7 @@ class MultiheadAttention(torch.nn.Module):
encoder_output: Optional[torch.Tensor] = None,
attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None,
......@@ -667,6 +676,11 @@ class MultiheadAttention(torch.nn.Module):
aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
sliding window size for local attention.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using
``layer_type="decoder"``.
......@@ -731,6 +745,17 @@ class MultiheadAttention(torch.nn.Module):
if window_size is None:
window_size = self.window_size
window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
if "padding" in attn_mask_type and attention_mask is not None:
for mask in attention_mask:
assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"
......@@ -739,9 +764,6 @@ class MultiheadAttention(torch.nn.Module):
core_attention_bias_type in AttnBiasTypes
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# =================================================
# Pre-allocate memory for key-value cache for inference
# =================================================
......@@ -1004,6 +1026,7 @@ class MultiheadAttention(torch.nn.Module):
attention_mask=attention_mask,
attn_mask_type=attn_mask_type,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
......
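A usage sketch of the new knob on MultiheadAttention; only bottom_right_diagonal, window_size, and attn_mask_type are taken from the hunks above, the sizes, device handling, and remaining constructor arguments are illustrative assumptions:

import torch
from transformer_engine.pytorch import MultiheadAttention

# Sliding-window self-attention with the window diagonal pinned to the bottom right corner.
mha = MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    attn_mask_type="no_mask",
    window_size=(128, 0),          # (left, right) sliding window
    bottom_right_diagonal=True,    # align the window diagonal to the bottom right corner
)
x = torch.randn(512, 2, 1024, device="cuda")  # [seq, batch, hidden] for the default qkv_format
out = mha(x)
# Passing bottom_right_diagonal=None (the default) defers to the mask-type based rule
# documented in the forward() docstring above.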
......@@ -137,6 +137,7 @@ def fused_attn_fwd(
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
bottom_right_diagonal: bool = None,
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
return_max_logit: bool = False,
......@@ -212,6 +213,9 @@ def fused_attn_fwd(
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
bottom_right_diagonal: bool, default = None
whether to align sliding window and ALiBi diagonal to the top left (False) or
bottom right (True) corner of the softmax matrix; if None, it is inferred from
attn_mask_type (True for causal_bottom_right and padding_causal_bottom_right, False otherwise).
rng_gen : torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
......@@ -255,6 +259,12 @@ def fused_attn_fwd(
max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None
"""
if bottom_right_diagonal is None:
bottom_right_diagonal = attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}
if attn_scale is None:
d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d)
......@@ -306,6 +316,7 @@ def fused_attn_fwd(
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
bottom_right_diagonal,
cu_seqlens_q,
cu_seqlens_kv,
q,
......@@ -370,6 +381,7 @@ def fused_attn_bwd(
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
bottom_right_diagonal: bool = None,
deterministic: bool = False,
cuda_graph: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -442,6 +454,9 @@ def fused_attn_bwd(
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
bottom_right_diagonal: bool, default = None
whether to align sliding window and ALiBi diagonal to the top left (False) or
bottom right (True) corner of the softmax matrix; if None, it is inferred from
attn_mask_type (True for causal_bottom_right and padding_causal_bottom_right, False otherwise).
deterministic : bool, default = False
whether to execute the backward pass with deterministic behaviours.
cuda_graph : bool, default = False
......@@ -462,6 +477,12 @@ def fused_attn_bwd(
gradient tensor of softmax offset of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
"""
if bottom_right_diagonal is None:
bottom_right_diagonal = attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}
if attn_scale is None:
d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d)
......@@ -500,6 +521,7 @@ def fused_attn_bwd(
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
bottom_right_diagonal,
deterministic,
cu_seqlens_q,
cu_seqlens_kv,
......
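The sliding window band quoted in the fused_attn_fwd and fused_attn_bwd docstrings can be reproduced with a small reference mask builder, which also shows what toggling bottom_right_diagonal does. This is an illustrative sketch: the helper name and the top-left variant are assumptions, only the bottom-right band follows the documented formula.

import torch

def sliding_window_mask(seqlen_q: int, seqlen_kv: int,
                        window_size=(-1, -1), bottom_right_diagonal=True) -> torch.Tensor:
    # True marks the key positions j that query position i may attend to.
    # Bottom-right alignment keeps j in
    #   [i + seqlen_kv - seqlen_q - window_size[0], i + seqlen_kv - seqlen_q + window_size[1]],
    # as in the docstrings above; top-left alignment drops the seqlen_kv - seqlen_q offset.
    left = seqlen_kv if window_size[0] < 0 else window_size[0]
    right = seqlen_kv if window_size[1] < 0 else window_size[1]
    offset = seqlen_kv - seqlen_q if bottom_right_diagonal else 0
    i = torch.arange(seqlen_q).unsqueeze(1)
    j = torch.arange(seqlen_kv).unsqueeze(0)
    return (j >= i + offset - left) & (j <= i + offset + right)

# (-1, 0) reproduces a causal mask; with bottom_right_diagonal=True the last query row
# attends to every key, with False the band starts from the top left corner instead.
print(sliding_window_mask(4, 6, window_size=(-1, 0)).int())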