OpenDAS / TransformerEngine · Commits

Commit 9df0c4a3, authored Mar 03, 2026 by wenjh
Merge branch 'nv_main'
Parents: 0d874a4e, f122b07d
Changes: 221

Showing 20 changed files with 692 additions and 167 deletions (+692 −167)
transformer_engine/jax/cpp_extensions/activation.py  +1 −0
transformer_engine/jax/cpp_extensions/attention.py  +35 −8
transformer_engine/jax/csrc/extensions.h  +6 −3
transformer_engine/jax/csrc/extensions/activation.cpp  +6 −0
transformer_engine/jax/csrc/extensions/amax.cpp  +0 −2
transformer_engine/jax/csrc/extensions/attention.cpp  +37 −33
transformer_engine/jax/csrc/extensions/inspect.cpp  +99 −0
transformer_engine/jax/csrc/extensions/pybind.cpp  +4 −0
transformer_engine/jax/debug/__init__.py  +11 −0
transformer_engine/jax/debug/experimental/__init__.py  +14 −0
transformer_engine/jax/debug/experimental/inspect.py  +174 −0
transformer_engine/jax/permutation.py  +37 −23
transformer_engine/jax/triton_extensions/permutation.py  +18 −11
transformer_engine/jax/triton_extensions/utils.py  +6 −3
transformer_engine/pytorch/attention/dot_product_attention/backends.py  +66 −1
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py  +10 −8
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py  +51 −17
transformer_engine/pytorch/attention/dot_product_attention/utils.py  +68 −54
transformer_engine/pytorch/attention/multi_head_attention.py  +27 −4
transformer_engine/pytorch/cpp_extensions/fused_attn.py  +22 −0
transformer_engine/jax/cpp_extensions/activation.py  @ 9df0c4a3
...
...
@@ -44,6 +44,7 @@ __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]
ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("sigmoid", "linear"): NVTE_Activation_Type.GLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
...
...
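For reference, the ActivationEnum mapping above is keyed by tuples of activation names, so a gated activation is selected by passing the pair of component activations. A minimal lookup sketch (ActivationEnum and NVTE_Activation_Type are the objects shown in the hunk above; the caller context is illustrative):

    # Resolve the fused activation kernel enum from the activation tuple used by the layer.
    gated = ActivationEnum[("silu", "linear")]   # -> NVTE_Activation_Type.SWIGLU (gated SiLU)
    single = ActivationEnum[("gelu",)]           # -> NVTE_Activation_Type.GELU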
transformer_engine/jax/cpp_extensions/attention.py  @ 9df0c4a3
...
...
@@ -70,6 +70,7 @@ __all__ = [
    "is_training",
    "max_segments_per_seq",
    "window_size",
    "bottom_right_diagonal",
    "context_parallel_load_balanced",
    "cp_axis",
    "cp_striped_window_size",
...
...
@@ -91,6 +92,7 @@ class _FusedAttnConfig:
    is_training: bool
    max_segments_per_seq: int
    window_size: Tuple[int, int]
    bottom_right_diagonal: bool
    context_parallel_load_balanced: bool
    cp_axis: str
    cp_striped_window_size: Tuple[int, int]  # Only for CP + Ring P2P + THD + SWA
...
...
@@ -144,6 +146,7 @@ class FusedAttnHelper:
            self.head_dim_v,
            self.window_size[0],
            self.window_size[1],
            not self.is_non_deterministic_allowed(),
        )

    @staticmethod
...
...
@@ -370,6 +373,11 @@ class FusedAttnFwdPrimitive(BasePrimitive):
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        bottom_right_diagonal = config.attn_mask_type in [
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        ]

        # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
        # prepare for the active fused-attn backend
        input_batch = reduce(operator.mul, batch_shape)
...
...
@@ -394,6 +402,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
            bottom_right_diagonal,
        )
        wkspace_aval = q_aval.update(shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
...
...
@@ -502,6 +511,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            bottom_right_diagonal=config.bottom_right_diagonal,
            softmax_type=int(config.softmax_type.value),
        )
...
...
@@ -812,6 +822,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
            config.max_segments_per_seq,
            config.window_size[0],
            config.window_size[1],
            config.bottom_right_diagonal,
        )
        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
...
...
@@ -947,6 +958,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
            deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
            window_size_left=window_size_left,
            window_size_right=window_size_right,
            bottom_right_diagonal=config.bottom_right_diagonal,
            softmax_type=int(config.softmax_type.value),
        )
...
...
@@ -1356,9 +1368,10 @@ class _FusedAttnCPWithAllGatherHelper:
    def get_step_config(self) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call to fused attention."""
        adjusted_mask = self.get_adjusted_mask()
        return _FusedAttnConfig(
            attn_bias_type=self.config.attn_bias_type,
-           attn_mask_type=self.get_adjusted_mask(),
+           attn_mask_type=adjusted_mask,
            softmax_type=self.config.softmax_type,
            qkv_layout=self.config.qkv_layout,
            scaling_factor=self.config.scaling_factor,
...
...
@@ -1366,6 +1379,7 @@ class _FusedAttnCPWithAllGatherHelper:
            is_training=self.config.is_training,
            max_segments_per_seq=self.config.max_segments_per_seq,
            window_size=self.config.window_size,
            bottom_right_diagonal=adjusted_mask.is_bottom_right(),
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
            cp_striped_window_size=None,
...
...
@@ -1374,9 +1388,10 @@ class _FusedAttnCPWithAllGatherHelper:
    def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig:
        """Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention."""
        adjusted_mask = self.get_adjusted_mask()
        return _FusedAttnConfig(
            attn_bias_type=self.config.attn_bias_type,
-           attn_mask_type=self.get_adjusted_mask(),
+           attn_mask_type=adjusted_mask,
            softmax_type=self.config.softmax_type,
            qkv_layout=self.config.qkv_layout,
            scaling_factor=self.config.scaling_factor,
...
...
@@ -1384,6 +1399,7 @@ class _FusedAttnCPWithAllGatherHelper:
            is_training=self.config.is_training,
            max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size),
            window_size=self.config.window_size,
            bottom_right_diagonal=adjusted_mask.is_bottom_right(),
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
            cp_striped_window_size=None,
...
...
@@ -2429,6 +2445,7 @@ class _FusedAttnCPWithP2PHelper:
            is_training=self.config.is_training,
            max_segments_per_seq=self.config.max_segments_per_seq,
            window_size=self.config.window_size,
            bottom_right_diagonal=attn_mask_type.is_bottom_right(),
            context_parallel_load_balanced=self.config.context_parallel_load_balanced,
            cp_axis=self.config.cp_axis,
            cp_striped_window_size=None,
...
...
@@ -3417,6 +3434,7 @@ def fused_attn_fwd(
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=(-1, -1) if window_size is None else window_size,
        bottom_right_diagonal=attn_mask_type.is_bottom_right(),
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
        cp_striped_window_size=None,
...
...
@@ -3563,13 +3581,21 @@ def fused_attn_bwd(
            softmax_offset, (None, HEAD_AXES, None, None))

    # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
    # sm100+
    compute_capabilities = get_all_device_compute_capability()
-   if any(x >= 100 for x in compute_capabilities):
-       assert not (
-           attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
-       ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
+   if any(x >= 100 for x in compute_capabilities) and is_training:
+       assert (
+           FusedAttnHelper.is_non_deterministic_allowed()
+           and get_cudnn_version() >= (9, 7, 0)
+           and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0)
+       ) or (
+           not FusedAttnHelper.is_non_deterministic_allowed()
+           and get_cudnn_version() >= (9, 18, 1)
+           and attn_bias_type == AttnBiasType.NO_BIAS
+           and dropout_probability == 0.0
+       ), (
+           "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout,"
+           " and deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout"
+       )

    fused_config = _FusedAttnConfig(
        attn_bias_type=attn_bias_type,
...
...
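As a reading aid, the sm100+ constraint introduced above can be restated as a predicate: on Blackwell-class devices during training, the non-deterministic backward needs cuDNN 9.7.0+ and at most one of bias/dropout, while the deterministic backward needs cuDNN 9.18.1+ and neither. A hedged sketch of that logic (the standalone helper below is illustrative, not part of the diff):

    def sm100_bwd_supported(non_deterministic, cudnn_version, has_bias, dropout_p):
        """Illustrative restatement of the sm100+ assertion in fused_attn_bwd above."""
        if non_deterministic:
            # cuDNN 9.7+ path: bias and dropout may not be combined
            return cudnn_version >= (9, 7, 0) and not (has_bias and dropout_p != 0.0)
        # Deterministic path: cuDNN 9.18.1+, and neither bias nor dropout
        return cudnn_version >= (9, 18, 1) and not has_bias and dropout_p == 0.0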
@@ -3581,6 +3607,7 @@ def fused_attn_bwd(
        is_training=is_training,
        max_segments_per_seq=max_segments_per_seq,
        window_size=(-1, -1) if window_size is None else window_size,
        bottom_right_diagonal=attn_mask_type.is_bottom_right(),
        context_parallel_load_balanced=context_parallel_causal_load_balanced,
        cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
        cp_striped_window_size=None,
...
...
transformer_engine/jax/csrc/extensions.h  @ 9df0c4a3
...
...
@@ -113,7 +113,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
    size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
-   int64_t window_size_right);
+   int64_t window_size_right, bool deterministic);

pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
...
...
@@ -121,7 +121,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
-   int64_t window_size_right);
+   int64_t window_size_right, bool bottom_right_diagonal);

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
...
...
@@ -129,7 +129,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
-   int64_t window_size_left, int64_t window_size_right);
+   int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal);

// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
...
...
@@ -143,6 +143,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);

// Inspect
XLA_FFI_DECLARE_HANDLER_SYMBOL(InspectHandler);

// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
...
...
transformer_engine/jax/csrc/extensions/activation.cpp  @ 9df0c4a3
...
...
@@ -109,6 +109,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GLU:
      nvte_glu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
...
...
@@ -427,6 +430,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GLU:
      nvte_dglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
...
...
transformer_engine/jax/csrc/extensions/amax.cpp  @ 9df0c4a3
...
...
@@ -5,8 +5,6 @@
************************************************************************/
#include <cuda_runtime.h>
#include <iostream>
#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
...
...
transformer_engine/jax/csrc/extensions/attention.cpp  @ 9df0c4a3
...
...
@@ -16,12 +16,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
    size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
-   int64_t window_size_right) {
+   int64_t window_size_right, bool deterministic) {
  auto backend = nvte_get_fused_attn_backend(
      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-     false, false);
+     false, false, deterministic);
  return backend;
}
...
...
@@ -144,7 +144,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
-   int64_t window_size_right) {
+   int64_t window_size_right, bool bottom_right_diagonal) {
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
...
...
@@ -192,7 +192,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
        ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
-       window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
+       window_size_left, window_size_right, bottom_right_diagonal,
+       query_workspace_tensor.data(), nullptr);
  }
  nvte_tensor_pack_destroy(&aux_output_tensors);
...
...
@@ -237,7 +238,7 @@ static void FusedAttnForwardImpl(
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
-   int64_t window_size_left, int64_t window_size_right) {
+   int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
  FUSED_ATTN_IMPL_COMMON_BLOCK;

  /* Input tensors */
...
...
@@ -266,7 +267,7 @@ static void FusedAttnForwardImpl(
      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-     false, false);
+     false, false, deterministic);
  nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

  /* Auxiliary tensors (to be propagated to the backward pass later) */
...
...
@@ -328,7 +329,7 @@ static void FusedAttnForwardImpl(
      k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
      rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
      scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
-     window_size_left, window_size_right, workspace_tensor.data(), stream);
+     window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream);

  nvte_tensor_pack_destroy(&aux_output_tensors);
}
...
...
@@ -346,6 +347,7 @@ static void FusedAttnForwardImpl(
  size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
  auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
  auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
  bool bottom_right_diagonal = get_attr_value<bool>(attrs, "bottom_right_diagonal"); \
  float scaling_factor = get_attr_value<double>(attrs, "scaling_factor"); \
  float dropout_probability = get_attr_value<double>(attrs, "dropout_probability"); \
  NVTE_Bias_Type bias_type = \
...
...
@@ -384,7 +386,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
      input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
      qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
      dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype,
-     is_training, deterministic, window_size_left, window_size_right);
+     is_training, deterministic, window_size_left, window_size_right, bottom_right_diagonal);
  return ffi_with_cuda_error_check();
}
...
...
@@ -415,7 +417,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
-   int64_t window_size_left, int64_t window_size_right) {
+   int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
  auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
  auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
...
...
@@ -467,17 +469,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    auto dummy_ragged_offset_tensor =
        TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
-   nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
-                       doutput_tensor.data(),
-                       s_tensor.data(),  // not used for F16
-                       s_tensor.data(),  // not used for F16
-                       &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-                       dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
-                       q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                       dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-                       q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
-                       bias_type, mask_type, softmax_type, window_size_left, window_size_right,
-                       deterministic, false, query_workspace_tensor.data(), nullptr);
+   nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+                       doutput_tensor.data(),
+                       s_tensor.data(),  // not used for F16
+                       s_tensor.data(),  // not used for F16
+                       &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
+                       dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
+                       q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
+                       dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
+                       q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
+                       bias_type, mask_type, softmax_type, window_size_left, window_size_right,
+                       bottom_right_diagonal, deterministic, false,
+                       query_workspace_tensor.data(), nullptr);
  }
  nvte_tensor_pack_destroy(&aux_input_tensors);
...
...
@@ -496,7 +499,7 @@ static void FusedAttnBackwardImpl(
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
    DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
-   int64_t window_size_left, int64_t window_size_right) {
+   int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
  FUSED_ATTN_IMPL_COMMON_BLOCK;

  /* Input tensors */
...
...
@@ -522,7 +525,7 @@ static void FusedAttnBackwardImpl(
      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-     false, false);
+     false, false, deterministic);
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                     bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                     softmax_aux, rng_state, bias, softmax_offset);
...
...
@@ -593,16 +596,17 @@ static void FusedAttnBackwardImpl(
    }
  }

- nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
-                     doutput_tensor.data(),
-                     s_tensor.data(),  // not used for F16
-                     s_tensor.data(),  // not used for F16
-                     &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-                     dbias_tensor.data(), dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
-                     kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
-                     k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
-                     dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
-                     window_size_left, window_size_right, deterministic, false,
-                     workspace_tensor.data(), stream);
+ nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+                     doutput_tensor.data(),
+                     s_tensor.data(),  // not used for F16
+                     s_tensor.data(),  // not used for F16
+                     &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
+                     dbias_tensor.data(), dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
+                     kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
+                     k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
+                     dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
+                     window_size_left, window_size_right, bottom_right_diagonal, deterministic,
+                     false, workspace_tensor.data(), stream);

  nvte_tensor_pack_destroy(&aux_input_tensors);
}
...
...
@@ -631,7 +635,7 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T
      q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim,
      max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type,
      softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
-     window_size_right);
+     window_size_right, bottom_right_diagonal);
  return ffi_with_cuda_error_check();
}
...
...
transformer_engine/jax/csrc/extensions/inspect.cpp  (new file, 0 → 100644)  @ 9df0c4a3
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda_runtime.h>

#include <fstream>
#include <iostream>

#include "../extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type min_buf,
                      Buffer_Type max_buf, Buffer_Type mean_buf, Buffer_Type std_buf,
                      Result_Type output_buf) {
  NVTE_CHECK(input_buf.untyped_data() != nullptr, "Input must be provided for inspect operation");
  NVTE_CHECK(output_buf->untyped_data() != nullptr,
             "Output must be provided for inspect operation");
  NVTE_CHECK(input_buf.untyped_data() == output_buf->untyped_data(),
             "Input and output must point to the same buffer for inspect operation");

  std::vector<uint8_t> input_data(input_buf.size_bytes());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(input_data.data(), input_buf.untyped_data(),
                                  input_buf.size_bytes(), cudaMemcpyDeviceToHost, stream));

  float min_val{}, max_val{}, mean_val{}, std_val{};
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&min_val, min_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&max_val, max_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&mean_val, mean_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&std_val, std_buf.untyped_data(), sizeof(float),
                                  cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));

  int device;
  NVTE_CHECK_CUDA(cudaGetDevice(&device));

  // Write the tensor data to a file as a binary blob
  std::string filename = "my_tensor_gpu" + std::to_string(device) + ".bin";
  std::ofstream file(filename, std::ios::binary);
  NVTE_CHECK(file.is_open(), "Failed to create file: ", filename);
  file.write(reinterpret_cast<const char *>(input_data.data()), input_data.size());
  file.close();

  // Write out a metadata file
  std::string meta_filename = "my_tensor_gpu" + std::to_string(device) + "_meta.json";
  std::ofstream meta_file(meta_filename);
  NVTE_CHECK(meta_file.is_open(), "Failed to create file: ", meta_filename);
  meta_file << "{";
  meta_file << "\"shape\": [";
  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
    meta_file << input_buf.dimensions()[i];
    if (i < input_buf.dimensions().size() - 1) {
      meta_file << ", ";
    }
  }
  meta_file << "], ";
  meta_file << "\"dtype\": " << static_cast<int>(input_buf.element_type());
  meta_file << ", \"min\": " << min_val;
  meta_file << ", \"max\": " << max_val;
  meta_file << ", \"mean\": " << mean_val;
  meta_file << ", \"std\": " << std_val;
  meta_file << "}";
  meta_file.close();

  // Log the tensor metadata to the console
  printf("[gpu%d]: Tensor data written to %s (shape: [", device, filename.c_str());
  for (size_t i = 0; i < input_buf.dimensions().size(); ++i) {
    printf("%zu", static_cast<size_t>(input_buf.dimensions()[i]));
    if (i < input_buf.dimensions().size() - 1) {
      printf(", ");
    }
  }
  printf("], dtype: %d", static_cast<int>(input_buf.element_type()));
  printf(", min: %f, max: %f, mean: %f, std: %f)\n", min_val, max_val, mean_val, std_val);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // min
                                  .Arg<Buffer_Type>()      // max
                                  .Arg<Buffer_Type>()      // mean
                                  .Arg<Buffer_Type>()      // std
                                  .Ret<Buffer_Type>()      // output
);

}  // namespace jax
}  // namespace transformer_engine
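The handler above writes one my_tensor_gpu<N>.bin blob plus a my_tensor_gpu<N>_meta.json file per device. A minimal read-back sketch under stated assumptions (NumPy is available; the dtype field in the metadata is an XLA element-type enum, so the NumPy dtype below is a caller-supplied assumption that must match the dumped tensor):

    import json
    import numpy as np

    # Illustrative read-back of the files written by InspectFFI on GPU 0.
    with open("my_tensor_gpu0_meta.json") as f:
        meta = json.load(f)
    # The dtype here is an assumption; map meta["dtype"] to the real element type yourself.
    data = np.fromfile("my_tensor_gpu0.bin", dtype=np.float32).reshape(meta["shape"])
    print(meta["min"], meta["max"], meta["mean"], meta["std"])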
transformer_engine/jax/csrc/extensions/pybind.cpp  @ 9df0c4a3
...
...
@@ -81,6 +81,9 @@ pybind11::dict Registrations() {
                          pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
                          pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler));

  dict["te_inspect_ffi"] = pybind11::dict(pybind11::arg("execute") = EncapsulateFFI(InspectHandler));

  return dict;
}
...
...
@@ -150,6 +153,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
  pybind11::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", pybind11::module_local())
      .value("GELU", NVTE_Activation_Type::GELU)
      .value("GEGLU", NVTE_Activation_Type::GEGLU)
      .value("GLU", NVTE_Activation_Type::GLU)
      .value("SILU", NVTE_Activation_Type::SILU)
      .value("SWIGLU", NVTE_Activation_Type::SWIGLU)
      .value("RELU", NVTE_Activation_Type::RELU)
...
...
transformer_engine/jax/debug/__init__.py  (new file, 0 → 100644)  @ 9df0c4a3
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.

This API is experimental and may change or be removed without deprecation in future releases.
"""

__all__ = [
    "experimental",
]
transformer_engine/jax/debug/experimental/__init__.py  (new file, 0 → 100644)  @ 9df0c4a3
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""EXPERIMENTAL debugging utilities for Transformer Engine JAX.

This API is experimental and may change or be removed without deprecation in future releases.
"""

from .inspect import inspect_array, load_array_dump

__all__ = [
    "inspect_array",
    "load_array_dump",
]
transformer_engine/jax/debug/experimental/inspect.py  (new file, 0 → 100644)  @ 9df0c4a3
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Experimental JAX array inspection utilities."""

from functools import partial

import jax
import jax.numpy as jnp
from jax import ffi

from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive

__all__ = ["inspect_array", "load_array_dump"]


class InspectPrimitive(BasePrimitive):
    """
    No-op used for inspect array values.
    """

    name = "te_inspect_ffi"
    multiple_results = False
    impl_static_args = ()
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, x_min_aval, x_max_aval, x_mean_aval, x_std_aval):
        """
        inspect abstract
        """
        assert (
            x_min_aval.shape == () and x_min_aval.dtype == jnp.float32
        ), "x_min must be a scalar with dtype float32"
        assert (
            x_max_aval.shape == () and x_max_aval.dtype == jnp.float32
        ), "x_max must be a scalar with dtype float32"
        assert (
            x_mean_aval.shape == () and x_mean_aval.dtype == jnp.float32
        ), "x_mean must be a scalar with dtype float32"
        assert (
            x_std_aval.shape == () and x_std_aval.dtype == jnp.float32
        ), "x_std must be a scalar with dtype float32"
        return x_aval

    @staticmethod
    def lowering(ctx, x, x_min, x_max, x_mean, x_std):
        """
        inspect lowering rules
        """
        return ffi.ffi_lowering(
            InspectPrimitive.name,
            operand_output_aliases={0: 0},  # donate input buffer to output buffer
        )(ctx, x, x_min, x_max, x_mean, x_std)

    @staticmethod
    def impl(x, x_min, x_max, x_mean, x_std):
        """
        inspect implementation
        """
        assert InspectPrimitive.inner_primitive is not None
        (x) = InspectPrimitive.inner_primitive.bind(x, x_min, x_max, x_mean, x_std)
        return x


register_primitive(InspectPrimitive)


def _inspect_array_inner(x: jnp.ndarray) -> jnp.ndarray:
    assert InspectPrimitive.outer_primitive is not None, (
        "InspectPrimitive FFI is not registered. Please ensure the C++ extension is properly built"
        " and registered."
    )
    return InspectPrimitive.outer_primitive.bind(
        x,
        jnp.min(x).astype(jnp.float32),
        jnp.max(x).astype(jnp.float32),
        jnp.mean(x.astype(jnp.float32)),
        jnp.std(x.astype(jnp.float32)),
    )


@partial(jax.custom_vjp, nondiff_argnums=())
def _inspect(x):
    """ """
    output, _ = _inspect_fwd_rule(x)
    return output


def _inspect_fwd_rule(x):
    """"""
    ctx = ()
    x = _inspect_array_inner(x)
    return x, ctx


def _inspect_bwd_rule(ctx, grad):
    """"""
    del ctx
    return (grad,)


_inspect.defvjp(_inspect_fwd_rule, _inspect_bwd_rule)


def inspect_array(x: jnp.ndarray, name: str) -> jnp.ndarray:
    """Utility function to inspect JAX arrays by printing their name, shape, dtype, and statistics.

    Args:
        x (jnp.ndarray): The JAX array to inspect.
        name (str): The name of the array for identification in the output.
    """
    del name  # Name is currently unused, but can be included in the future for more informative output
    return _inspect(x)


def load_array_dump(filename: str, shape: tuple, dtype: jnp.dtype) -> jnp.ndarray:
    """Utility function to load a JAX array from a dumped binary file.

    Args:
        filename (str): The path to the binary file containing the array data.
        shape (tuple): The shape of the array to be loaded.
        dtype (jnp.dtype): The data type of the array to be loaded.

    Returns:
        jnp.ndarray: The loaded JAX array.
    """
    with open(filename, "rb") as f:
        data = f.read()
    array = jnp.frombuffer(data, dtype=dtype).reshape(shape)
    return array
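A hedged usage sketch of the new experimental API (the import path follows the new files above; the surrounding model code, shapes, and file name are illustrative):

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.debug.experimental import inspect_array, load_array_dump

    @jax.jit
    def layer(x):
        # inspect_array is a no-op in the computation; as a side effect the FFI handler
        # dumps the buffer and its min/max/mean/std per GPU.
        x = inspect_array(x, name="pre_activation")
        return jax.nn.gelu(x)

    y = layer(jnp.ones((4, 8), dtype=jnp.bfloat16))
    # Later, reload the dump written by GPU 0; shape and dtype must be supplied by the caller.
    recovered = load_array_dump("my_tensor_gpu0.bin", shape=(4, 8), dtype=jnp.bfloat16)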
transformer_engine/jax/permutation.py  @ 9df0c4a3
...
...
@@ -52,7 +52,7 @@ def token_dispatch(
    Optional[jnp.ndarray],
    jnp.ndarray,
    Optional[jnp.ndarray],
-   Optional[jnp.ndarray],
+   jnp.ndarray,
]:
    """
    Dispatch tokens to experts based on routing map.
...
...
@@ -101,9 +101,11 @@ def token_dispatch(
    pad_offsets : Optional[jnp.ndarray]
        Per-expert cumulative padding offsets of shape [num_experts] when using padding,
        None otherwise. Pass this to token_combine when unpadding is needed.
-   target_tokens_per_expert : Optional[jnp.ndarray]
-       Aligned token counts per expert of shape [num_experts] when using padding,
-       None otherwise.
+   tokens_per_expert : jnp.ndarray
+       Token counts per expert of shape [num_experts]:
+       - Without padding: actual token counts (sum of routing_map columns)
+       - With padding: aligned token counts (ceil(actual / align_size) * align_size)
+       This gives the effective number of tokens per expert in the output buffer.

    Note
    ----
...
...
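The padding rule documented in the hunk above rounds each expert's token count up to a multiple of align_size. A small illustrative example (the counts and align_size are made up):

    import jax.numpy as jnp

    tokens_per_expert = jnp.array([5, 16, 0, 9], dtype=jnp.int32)  # actual counts
    align_size = 8
    # With padding enabled, token_dispatch returns the aligned counts instead:
    aligned = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(jnp.int32)
    # -> [ 8, 16,  0, 16]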
@@ -151,10 +153,10 @@ def _token_dispatch(
    Optional[jnp.ndarray],
    jnp.ndarray,
    Optional[jnp.ndarray],
-   Optional[jnp.ndarray],
+   jnp.ndarray,
]:
    """Internal token_dispatch with custom VJP."""
-   (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
+   (output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert), _ = (
        _token_dispatch_fwd_rule(
            inp,
            routing_map,
...
...
@@ -165,7 +167,7 @@ def _token_dispatch(
            use_padding,
        )
    )
-   return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
+   return output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert


def _token_dispatch_fwd_rule(
...
...
@@ -182,7 +184,7 @@ def _token_dispatch_fwd_rule(
        Optional[jnp.ndarray],
        jnp.ndarray,
        Optional[jnp.ndarray],
-       Optional[jnp.ndarray],
+       jnp.ndarray,
    ],
    Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
]:
...
...
@@ -212,11 +214,11 @@ def _token_dispatch_fwd_rule(
    with_probs = probs is not None

-   if use_padding:
-       # Compute tokens_per_expert internally from routing_map
-       # This can be a traced value since output shape uses worst_case_out_tokens
+   # Compute tokens_per_expert from routing_map (actual counts)
+   # This is well-optimized by XLA as a simple column-wise reduction
    tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
+   if use_padding:
+       # Calculate aligned token counts per expert
        target_tokens_per_expert = (
            jnp.ceil(tokens_per_expert / align_size) * align_size
        ).astype(jnp.int32
...
...
@@ -242,10 +244,12 @@ def _token_dispatch_fwd_rule(
            hidden_size,
            align_size=align_size,
        )
+       # Return aligned counts when using padding
+       out_tokens_per_expert = target_tokens_per_expert
    else:
        # No padding
        pad_offsets = None
-       target_tokens_per_expert = None
        output, permuted_probs = permute_with_mask_map(
            inp,
...
...
@@ -257,14 +261,20 @@ def _token_dispatch_fwd_rule(
            hidden_size,
        )
+       # Return actual counts when not using padding
+       out_tokens_per_expert = tokens_per_expert

    # Return (primals, residuals)
+   # out_tokens_per_expert is:
+   #   - target_tokens_per_expert (aligned) when using padding
+   #   - tokens_per_expert (actual) when not using padding
    residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
    return (
        output,
        permuted_probs,
        row_id_map,
        pad_offsets,
-       target_tokens_per_expert,
+       out_tokens_per_expert,
    ), residuals
...
...
@@ -571,7 +581,7 @@ def sort_chunks_by_index(
    return _sort_chunks_by_index(inp, split_sizes, sorted_indices)


- @partial(jax.custom_vjp, nondiff_argnums=(1, 2))
+ @jax.custom_vjp
def _sort_chunks_by_index(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
...
...
@@ -586,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
- ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
+ ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]]:
    """Forward pass rule for sort_chunks_by_index."""
    # Validate input dimensions
    assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
...
...
@@ -608,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule(
    )

    # Return (primals, residuals)
-   residuals = (row_id_map, num_tokens, hidden_size)
+   # Include split_sizes and sorted_indices in residuals since we removed nondiff_argnums
+   residuals = (row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size)
    return (output, row_id_map), residuals


def _sort_chunks_by_index_bwd_rule(
-   _split_sizes: jnp.ndarray,
-   _sorted_indices: jnp.ndarray,
-   residuals: Tuple[jnp.ndarray, int, int],
+   residuals: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int],
    g: Tuple[jnp.ndarray, jnp.ndarray],
- ) -> Tuple[jnp.ndarray]:
+ ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Backward pass rule for sort_chunks_by_index."""
-   row_id_map, num_tokens, hidden_size = residuals
+   row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size = residuals
    output_grad, _ = g

    # Backward: reverse the sort
...
...
@@ -632,7 +641,12 @@ def _sort_chunks_by_index_bwd_rule(
        is_forward=False,
    )
-   return (inp_grad,)
+   # Return gradients for all inputs: (inp, split_sizes, sorted_indices)
+   # split_sizes and sorted_indices are integer arrays, so their gradients are zeros
+   split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype)
+   sorted_indices_grad = jnp.zeros_like(sorted_indices, dtype=sorted_indices.dtype)
+   return (inp_grad, split_sizes_grad, sorted_indices_grad)


_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)
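The backward-rule change above follows from dropping nondiff_argnums: with a plain jax.custom_vjp, the bwd rule must return one cotangent per primal input, so the integer-valued split_sizes and sorted_indices get zero "gradients", mirroring what the diff does. A minimal sketch of the same pattern on a toy function (all names below are illustrative, not part of the diff):

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def gather_rows(x, idx):
        return x[idx]

    def gather_rows_fwd(x, idx):
        return x[idx], (idx, x.shape)

    def gather_rows_bwd(residuals, g):
        idx, x_shape = residuals
        dx = jnp.zeros(x_shape, dtype=g.dtype).at[idx].add(g)
        # One cotangent per primal input; the integer index gets a zero gradient,
        # the same convention used by _sort_chunks_by_index_bwd_rule above.
        return dx, jnp.zeros_like(idx)

    gather_rows.defvjp(gather_rows_fwd, gather_rows_bwd)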
transformer_engine/jax/triton_extensions/permutation.py  @ 9df0c4a3
...
...
@@ -65,8 +65,6 @@ class RowIdMapPass1Primitive(BasePrimitive):
    @staticmethod
    def abstract(routing_map_aval, *, num_tokens, num_experts, block_size):
        """Shape/dtype inference for pass 1."""
-       del block_size  # Only affects grid, not output shape
        assert routing_map_aval.shape == (
            num_tokens,
            num_experts,
...
...
@@ -75,7 +73,7 @@ class RowIdMapPass1Primitive(BasePrimitive):
        row_id_map_shape = (num_tokens, num_experts * 2 + 1)
        workspace_shape = (
            num_experts,
-           triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE),
+           triton.cdiv(num_tokens, block_size),
        )
        return (
...
...
@@ -134,9 +132,10 @@ class RowIdMapPass1Primitive(BasePrimitive):
            desc="RowIdMapPass1.row_id_map_sharding",
        )
+       # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+       # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
        workspace_sharding = NamedSharding(
            mesh,
-           PartitionSpec(None, None),
+           PartitionSpec(None, routing_map_spec[0]),
            desc="RowIdMapPass1.workspace_sharding",
        )
        return [row_id_map_sharding, workspace_sharding]
...
...
@@ -156,9 +155,11 @@ class RowIdMapPass1Primitive(BasePrimitive):
            PartitionSpec(routing_map_spec[0], None),
            desc="RowIdMapPass1.row_id_map_sharding",
        )
+       # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+       # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
        workspace_sharding = NamedSharding(
            mesh,
-           PartitionSpec(None, None),
+           PartitionSpec(None, routing_map_spec[0]),
            desc="RowIdMapPass1.workspace_sharding",
        )
        out_shardings = [row_id_map_sharding, workspace_sharding]
...
...
@@ -186,7 +187,8 @@ class RowIdMapPass1Primitive(BasePrimitive):
        # Note: row_id_cols != experts since it's num_experts * 2 + 1
        row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_row_id_cols")
        # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
-       workspace_spec = (f"{prefix}_experts", f"{prefix}_ws_blocks")
+       # Second dim depends on num_tokens, so use same factor to ensure same sharding
+       workspace_spec = (f"{prefix}_experts", f"{prefix}_tokens")
        return SdyShardingRule((input_spec,), (row_id_map_spec, workspace_spec))
...
...
@@ -208,10 +210,9 @@ class RowIdMapPass2Primitive(BasePrimitive):
    def abstract(row_id_map_aval, workspace_aval, *, num_tokens, num_experts, block_size):
        """Shape/dtype inference for pass 2 (in-place operation)."""
        del row_id_map_aval, workspace_aval
-       del block_size
        row_id_map_shape = (num_tokens, num_experts * 2 + 1)
-       workspace_shape = (num_experts, triton.cdiv(num_tokens, DEFAULT_BLOCK_SIZE))
+       workspace_shape = (num_experts, triton.cdiv(num_tokens, block_size))
        return (
            jax.core.ShapedArray(row_id_map_shape, jnp.int32),
...
...
@@ -270,9 +271,11 @@ class RowIdMapPass2Primitive(BasePrimitive):
            PartitionSpec(*row_id_map_spec),
            desc="RowIdMapPass2.row_id_map_sharding",
        )
+       # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+       # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
        workspace_sharding = NamedSharding(
            mesh,
-           PartitionSpec(None, None),
+           PartitionSpec(None, row_id_map_spec[0]),
            desc="RowIdMapPass2.workspace_sharding",
        )
        return [row_id_map_sharding, workspace_sharding]
...
...
@@ -292,9 +295,11 @@ class RowIdMapPass2Primitive(BasePrimitive):
            PartitionSpec(*row_id_map_spec),
            desc="RowIdMapPass2.row_id_map_sharding",
        )
+       # Workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+       # Second dim depends on num_tokens, so it must be sharded on the same axis as tokens
        workspace_sharding = NamedSharding(
            mesh,
-           PartitionSpec(None, None),
+           PartitionSpec(None, row_id_map_spec[0]),
            desc="RowIdMapPass2.workspace_sharding",
        )
        out_shardings = [row_id_map_sharding, workspace_sharding]
...
...
@@ -317,7 +322,9 @@ class RowIdMapPass2Primitive(BasePrimitive):
        del num_tokens, num_experts, block_size, mesh, value_types, result_types
        prefix = "RowIdMapPass2"
        row_id_map_spec = (f"{prefix}_tokens", f"{prefix}_cols")
-       workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_ws_blocks")
+       # workspace shape: (num_experts, cdiv(num_tokens, BLOCK_SIZE))
+       # Second dim depends on num_tokens, so use same factor to ensure same sharding
+       workspace_spec = (f"{prefix}_ws_experts", f"{prefix}_tokens")
        return SdyShardingRule((row_id_map_spec, workspace_spec), (row_id_map_spec, workspace_spec))
...
...
transformer_engine/jax/triton_extensions/utils.py  @ 9df0c4a3
...
...
@@ -36,6 +36,8 @@ import warnings
from typing import Any, Callable, Mapping
import zlib

from packaging import version

from jax import core
import jax
import jax.numpy as jnp
...
...
@@ -274,13 +276,16 @@ def compile_triton(
        return _TRITON_KERNEL_CACHE[cache_key]

    # Compile kernel
    cuda_option_kwargs = {}
    if version.parse(_TRITON_VERSION) < version.parse("3.6.0"):
        cuda_option_kwargs["cluster_dims"] = (1, 1, 1)
    options = cb.CUDAOptions(
        num_warps=num_warps,
        num_stages=num_stages,
        num_ctas=num_ctas,
-       cluster_dims=(1, 1, 1),
        debug=False,
        enable_fp_fusion=enable_fp_fusion,
+       **cuda_option_kwargs,
    )

    # Mark constants as constexpr in signature
...
...
@@ -303,8 +308,6 @@ def compile_triton(
    # Create kernel object for JAX
    # From jax/jaxlib/gpu/triton_kernels.cc:
-   from packaging import version
-
    if version.parse(jax.__version__) >= version.parse("0.8.2"):
        kernel = gpu_triton.TritonKernel(
            compiled.name,  # arg0: kernel_name (str)
...
...
transformer_engine/pytorch/attention/dot_product_attention/backends.py  @ 9df0c4a3
...
...
@@ -166,6 +166,11 @@ class FP8EmulationFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
        # pylint: disable=missing-function-docstring
        if is_in_onnx_export_mode():
            return FP8EmulationFunc.onnx_forward(
                tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout
            )
        if quantizer_name == "QKV_quantizer":
            query_layer, key_layer, value_layer = [
                x.contiguous() for x in [tensor1, tensor2, tensor3]
...
...
@@ -204,6 +209,47 @@ class FP8EmulationFunc(torch.autograd.Function):
            tensors = grad1, grad2, grad3
        return tensors[0], tensors[1], tensors[2], None, None, None

    @staticmethod
    def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None):
        """
        ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations.
        """
        # pylint: disable=unused-argument
        is_qkv_quantizer = quantizer_name == "QKV_quantizer"
        assert isinstance(
            quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
        ), "ONNX FP8 emulation path supports only Float8 quantizers."
        if is_qkv_quantizer:
            # Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3.
            orig_dtype = tensor1.dtype
            shapes = [tensor1.shape, tensor2.shape, tensor3.shape]
            numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()]
            # Flatten and concatenate
            combined = torch.cat(
                [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
            )
            # Quantize + dequantize combined tensor using quantizer's ONNX methods
            combined_fp8 = quantizer.onnx_quantize(combined)
            out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype)
            # Split back
            out1 = out[: numels[0]].reshape(shapes[0])
            out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
            out3 = out[numels[0] + numels[1] :].reshape(shapes[2])
            return out1, out2, out3
        if quantizer_name in ["S_quantizer", "O_quantizer"]:
            # Emulate FP8 on single tensor using quantizer's ONNX methods
            orig_dtype = tensor1.dtype
            t_fp8 = quantizer.onnx_quantize(tensor1)
            out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype)
            return out, tensor2, tensor3
        # Pass-through
        return tensor1, tensor2, tensor3


class UnfusedDotProductAttention(torch.nn.Module):
    """Parallel attention w/o QKV and Proj Gemms
...
...
@@ -263,6 +309,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        bottom_right_diagonal: Optional[bool] = None,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
...
...
@@ -348,6 +395,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
                attention_mask=attention_mask,
                window_size=window_size,
                attention_type=self.attention_type,
                bottom_right_alignment=(
                    attn_mask_type not in ["causal", "padding_causal"]
                    if bottom_right_diagonal is None
                    else bottom_right_diagonal
                ),
            )
        )
...
...
@@ -451,7 +503,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
                actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
                actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
                alibi_slopes=alibi_slopes,
-               bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
+               bottom_right_alignment=(
+                   attn_mask_type not in ["causal", "padding_causal"]
+                   if bottom_right_diagonal is None
+                   else bottom_right_diagonal
+               ),
            )
            matmul_result = torch.baddbmm(
                matmul_result,
...
...
@@ -1112,6 +1168,7 @@ class FusedAttnFunc(torch.autograd.Function):
        attn_mask_type,
        softmax_type,
        window_size,
        bottom_right_diagonal,
        rng_gen,
        fused_attention_backend,
        use_FAv2_bwd,
...
...
@@ -1215,6 +1272,7 @@ class FusedAttnFunc(torch.autograd.Function):
                attn_mask_type,
                softmax_type,
                window_size,
                bottom_right_diagonal,
                rng_gen,
                softmax_offset,
                cuda_graph=is_graph_capturing(),
...
...
@@ -1292,6 +1350,7 @@ class FusedAttnFunc(torch.autograd.Function):
                attn_mask_type,
                softmax_type,
                window_size,
                bottom_right_diagonal,
                rng_gen,
                softmax_offset,
                return_max_logit,
...
...
@@ -1379,6 +1438,7 @@ class FusedAttnFunc(torch.autograd.Function):
        ctx.attn_mask_type = attn_mask_type
        ctx.softmax_type = softmax_type
        ctx.window_size = window_size
        ctx.bottom_right_diagonal = bottom_right_diagonal
        ctx.fused_attention_backend = (
            fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
        )
...
...
@@ -1529,6 +1589,7 @@ class FusedAttnFunc(torch.autograd.Function):
                ctx.attn_mask_type,
                ctx.softmax_type,
                ctx.window_size,
                ctx.bottom_right_diagonal,
                ctx.deterministic,
                is_graph_capturing(),
            )
...
...
@@ -1594,6 +1655,7 @@ class FusedAttnFunc(torch.autograd.Function):
                ctx.attn_mask_type,
                ctx.softmax_type,
                ctx.window_size,
                ctx.bottom_right_diagonal,
                ctx.deterministic,
                is_graph_capturing(),
            )
...
...
@@ -1633,6 +1695,7 @@ class FusedAttnFunc(torch.autograd.Function):
            None,
            None,
            None,
            None,
            d_softmax_offset,
            None,
            None,
...
...
@@ -1730,6 +1793,7 @@ class FusedAttention(torch.nn.Module):
        attn_mask_type: str = "causal",
        attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        window_size: Optional[Tuple[int, int]] = None,
        bottom_right_diagonal: Optional[bool] = None,
        fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
...
...
@@ -1937,6 +2001,7 @@ class FusedAttention(torch.nn.Module):
            attn_mask_type,
            self.softmax_type,
            window_size,
            bottom_right_diagonal,
            None,  # rng_gen
            fused_attention_backend,
            use_FAv2_bwd,
...
...
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py  @ 9df0c4a3
...
...
@@ -4026,28 +4026,30 @@ def attn_forward_func_with_cp(
    assert not sliding_window_attn or cp_comm_type in [
        "a2a",
        "all_gather",
-   ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!"
+   ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!"
    enable_mla = k.shape[-1] != v.shape[-1]
    assert not enable_mla or cp_comm_type in [
        "p2p",
        "a2a+p2p",
-   ], "Context parallelism does not support MLA with {cp_comm_type=}!"
+   ], f"Context parallelism does not support MLA with {cp_comm_type=}!"
    if fp8 and fp8_meta is not None:
        if fp8_meta["recipe"].fp8_dpa:
            assert (
                softmax_type == "vanilla"
-           ), "Context parallelism does not support {softmax_type=} with FP8 attention!"
+           ), f"Context parallelism does not support {softmax_type=} with FP8 attention!"
    assert (
        softmax_type == "vanilla" or use_fused_attention
-   ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!"
+   ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!"
    assert (
        softmax_type == "vanilla" or cp_comm_type == "a2a"
-   ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
-   assert (
-       softmax_type == "vanilla" or qkv_format != "thd"
-   ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
+   ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
+   if get_cudnn_version() < (9, 18, 0):
+       assert softmax_type == "vanilla" or qkv_format != "thd", (
+           f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with"
+           " qkv_format = 'thd'!"
+       )

    args = [
        is_training,
...
...
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py  @ 9df0c4a3
...
...
@@ -228,6 +228,11 @@ class DotProductAttention(TransformerEngineBaseModule):
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in ``forward`` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{'causal', 'padding_causal'} and `True` for other mask types.
attention_type : str, default = "self"
type of attention, either ``"self"`` and ``"cross"``.
layer_number : int, default = None
...
...
@@ -324,6 +329,7 @@ class DotProductAttention(TransformerEngineBaseModule):
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        bottom_right_diagonal: Optional[bool] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
...
...
@@ -350,6 +356,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            attn_mask_type = "padding_causal"
        self.attn_mask_type = attn_mask_type
        self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
        self.bottom_right_diagonal = bottom_right_diagonal
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
...
...
@@ -676,9 +683,9 @@ class DotProductAttention(TransformerEngineBaseModule):
        # assume attention uses the same fp8_group as GEMMs
        fp8_group = FP8GlobalStateManager.get_fp8_group()

-       self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-       self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
-       self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
+       self.fast_setattr("fp8_parameters", FP8GlobalStateManager.with_fp8_parameters())
+       self.fast_setattr("fp8", FP8GlobalStateManager.is_fp8_enabled())
+       self.fast_setattr("fp8_calibration", FP8GlobalStateManager.is_fp8_calibration())
        fp8_enabled = self.fp8 or self.fp8_calibration
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
        if self.fp8_parameters or fp8_enabled:
...
...
@@ -703,7 +710,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            )
        else:
            # If fp8 isn't enabled, turn off and return.
-           self.fp8_initialized = False
+           self.fast_setattr("fp8_initialized", False)
            return

        if self.fp8_parameters and not self.fp8_initialized:
...
...
@@ -721,7 +728,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(fp8_recipes)
-           self.fp8_initialized = True
+           self.fast_setattr("fp8_initialized", True)

            self.fp8_meta["recipe"] = fp8_recipe_dpa
            if fp8_recipe != fp8_recipe_dpa:
...
...
@@ -811,6 +818,7 @@ class DotProductAttention(TransformerEngineBaseModule):
        max_seqlen_kv: int = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        bottom_right_diagonal: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
...
...
@@ -963,6 +971,16 @@ class DotProductAttention(TransformerEngineBaseModule):
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention.
bottom_right_diagonal: Optional[bool], default = None
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{'causal', 'padding_causal'} and `True` for other mask types.
Note: this parameter is automatically overridden based on
`attn_mask_type`: it is forced to `False` for the 'causal' and
'padding_causal' mask types, and forced to `True` for mask types
containing 'bottom_right' (e.g., 'causal_bottom_right',
'padding_causal_bottom_right'), regardless of the explicitly passed
value (see the sketch after this hunk).
checkpoint_core_attention : bool, default = False
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
...
...
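As a reading aid (not part of this commit), the override rule described in the ``bottom_right_diagonal`` note above can be summarized by the following hypothetical helper; the module-level default is omitted for brevity:

def resolve_bottom_right_diagonal(attn_mask_type, bottom_right_diagonal=None):
    # Forced to False for purely causal masks, forced to True for *_bottom_right
    # masks or when left unspecified; otherwise the caller's value is kept.
    if attn_mask_type in {"causal", "padding_causal"}:
        return False
    if bottom_right_diagonal is None or attn_mask_type in {
        "causal_bottom_right",
        "padding_causal_bottom_right",
    }:
        return True
    return bottom_right_diagonal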
@@ -1000,7 +1018,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            cases. It is ignored for other backends and when context parallelism is enabled.
        """
-        with self.prepare_forward(
+        with self.prepare_forward_ctx(
            query_layer,
            num_gemms=3,
            allow_non_contiguous=True,
...
...
@@ -1081,6 +1099,15 @@ class DotProductAttention(TransformerEngineBaseModule):
        if window_size is None:
            window_size = self.window_size
        window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
+        if bottom_right_diagonal is None:
+            bottom_right_diagonal = self.bottom_right_diagonal
+        if attn_mask_type in {"causal", "padding_causal"}:
+            bottom_right_diagonal = False
+        if bottom_right_diagonal is None or attn_mask_type in {
+            "causal_bottom_right",
+            "padding_causal_bottom_right",
+        }:
+            bottom_right_diagonal = True
        # checks for qkv_format
        if qkv_format is None:
...
...
@@ -1144,8 +1171,11 @@ class DotProductAttention(TransformerEngineBaseModule):
            assert "padding" in attn_mask_type, "KV caching requires padding mask!"
            if attn_mask_type == "padding_causal":
                attn_mask_type = attn_mask_type + "_bottom_right"
+                # since attention mask is changed, set `bottom_right_diagonal` to True
+                bottom_right_diagonal = True
-            self.attention_type = "cross"
+            if self.attention_type != "cross":
+                self.fast_setattr("attention_type", "cross")
            self.flash_attention.attention_type = self.attention_type
            self.fused_attention.attention_type = self.attention_type
            self.unfused_attention.attention_type = self.attention_type
...
...
@@ -1256,7 +1286,6 @@ class DotProductAttention(TransformerEngineBaseModule):
        if self.layer_number == 1:
            _alibi_cache["_alibi_slopes_require_update"] = True
            _alibi_cache["_alibi_bias_require_update"] = True
-        bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]
        if core_attention_bias_type == "alibi":
            assert (
                core_attention_bias is None
...
...
@@ -1265,7 +1294,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                _alibi_cache["_num_heads"] != query_layer.shape[-2]
                or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
-                or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
+                or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal
                or _alibi_cache["_alibi_slopes"] is None
            ):
                _alibi_cache["_alibi_slopes_require_update"] = True
...
...
@@ -1322,6 +1351,7 @@ class DotProductAttention(TransformerEngineBaseModule):
            head_dim_v=head_dim_v,
            attn_mask_type=attn_mask_type,
            window_size=window_size,
+            bottom_right_diagonal=bottom_right_diagonal,
            alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias_shape=core_attention_bias_shape,
...
...
@@ -1445,9 +1475,7 @@ class DotProductAttention(TransformerEngineBaseModule):
        if use_fused_attention:
            fu_core_attention_bias_type = core_attention_bias_type
            fu_core_attention_bias = core_attention_bias
-            if core_attention_bias_type == "alibi" and (
-                alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
-            ):
+            if core_attention_bias_type == "alibi" and (alibi_slopes is not None):
                fu_core_attention_bias_type = "post_scale_bias"
                _, fu_core_attention_bias = dpa_utils.get_alibi(
                    _alibi_cache,
...
...
@@ -1456,7 +1484,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    max_seqlen_kv,
                    alibi_slopes=alibi_slopes,
                    bias_dtype=query_layer.dtype,
-                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
+                    bottom_right_alignment=bottom_right_diagonal,
                )
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
...
...
@@ -1474,6 +1502,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    window_size=window_size,
+                    bottom_right_diagonal=bottom_right_diagonal,
                    fused_attention_backend=fused_attention_backend,
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
...
...
@@ -1504,6 +1533,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    window_size=window_size,
+                    bottom_right_diagonal=bottom_right_diagonal,
                    fused_attention_backend=fused_attention_backend,
                    core_attention_bias_type=fu_core_attention_bias_type,
                    core_attention_bias=fu_core_attention_bias,
...
...
@@ -1522,7 +1552,9 @@ class DotProductAttention(TransformerEngineBaseModule):
                )
        if use_unfused_attention:
-            allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
+            allow_emulation = (
+                os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
+            )
            if checkpoint_core_attention:
                return self._checkpointed_attention_forward(
                    self.unfused_attention,
...
...
@@ -1538,6 +1570,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    window_size=window_size,
+                    bottom_right_diagonal=bottom_right_diagonal,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
...
...
@@ -1561,6 +1594,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    attn_mask_type=attn_mask_type,
                    attention_mask=attention_mask,
                    window_size=window_size,
+                    bottom_right_diagonal=bottom_right_diagonal,
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
...
...
transformer_engine/pytorch/attention/dot_product_attention/utils.py
View file @
9df0c4a3
...
...
@@ -200,6 +200,9 @@ class AttentionParams:
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size : Tuple[int, int], default = None
Sliding window attention size.
bottom_right_diagonal: bool, default = `None`
Whether to align sliding window and ALiBi diagonal to the bottom right corner
of the softmax matrix.
alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type : str, default = no_bias
...
...
@@ -249,6 +252,7 @@ class AttentionParams:
    head_dim_v: int = 64
    attn_mask_type: str = "no_mask"
    window_size: Union[Tuple[int, int], None] = None
+    bottom_right_diagonal: bool = True
    alibi_slopes_shape: Union[torch.Size, List, None] = None
    core_attention_bias_type: str = "no_bias"
    core_attention_bias_shape: str = "1hss"
...
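For illustration only, a hedged sketch of querying the backend selector with the new field; field names other than those shown in the hunk above (e.g. ``num_heads``, ``max_seqlen_q``) are assumed from context, and the remaining fields keep their defaults:

from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    AttentionParams,
    get_attention_backend,
)

params = AttentionParams(
    num_heads=16,
    num_gqa_groups=16,
    max_seqlen_q=512,
    max_seqlen_kv=2048,
    attn_mask_type="causal_bottom_right",
    window_size=(128, 0),
    bottom_right_diagonal=True,
)
# The helper logs which backends get filtered out and returns the selection.
backends = get_attention_backend(params)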
...
@@ -325,6 +329,7 @@ def get_attention_backend(
    head_dim_v = attention_params.head_dim_v
    attn_mask_type = attention_params.attn_mask_type
    window_size = attention_params.window_size
+    bottom_right_diagonal = attention_params.bottom_right_diagonal
    alibi_slopes_shape = attention_params.alibi_slopes_shape
    core_attention_bias_type = attention_params.core_attention_bias_type
    core_attention_bias_shape = attention_params.core_attention_bias_shape
...
...
@@ -474,7 +479,9 @@ def get_attention_backend(
            logger.debug("Disabling FlashAttention 3 for FP8 training")
            use_flash_attention_3 = False
        if use_unfused_attention:
-            allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
+            allow_emulation = (
+                os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
+            )
            if not allow_emulation:
                logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
                use_unfused_attention = False
...
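The emulation escape hatch referenced above can be enabled explicitly; with this commit, ONNX export mode is treated the same way:

import os
# Allow UnfusedDotProductAttention to emulate FP8 attention numerics.
os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"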
...
@@ -730,22 +737,14 @@ def get_attention_backend(
                )
                use_unfused_attention = False
        if qkv_format == "thd":
-            logger.debug(
-                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd",
-                softmax_type,
-            )
-            use_fused_attention = False
-            logger.debug(
-                "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
-                softmax_type,
-            )
-            use_unfused_attention = False
+            if cudnn_version < (9, 18, 0):
+                logger.debug(
+                    "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN"
+                    " version < 9.18",
+                    softmax_type,
+                )
+                use_fused_attention = False
        if context_parallel:
            logger.debug(
                "Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
                " = %s",
                softmax_type,
            )
            use_unfused_attention = False
            if cp_comm_type != "a2a":
                logger.debug(
                    "Disabling FusedAttention for context parallelism with softmax_type = %s and"
...
...
@@ -881,23 +880,21 @@ def get_attention_backend(
    # backend                    | window_size             | diagonal alignment
    # ---------------------------------------------------------------------------------
    # FlashAttention             | (-1, -1) or (>=0, >=0)  | bottom right
-    # FusedAttention             | (-1, 0) or (>=0, 0)     | top left
-    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0)  | both;
+    # FusedAttention             | (-1, 0) or (>=0, >=0)   | top left, bottom right
+    # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0)  | top left, bottom right
    #                            |                         | converts window_size to an 'arbitrary' mask
    if window_size is None:
        window_size = check_set_window_size(attn_mask_type, window_size)
    else:
        if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
            if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
                logger.debug(
-                    "Disabling FusedAttention as it does not support sliding window attention"
-                    " for FP8"
+                    "Disabling FusedAttention as it does not support sliding window attention for FP8"
                )
                use_fused_attention = False
-            elif window_size[1] != 0 or attention_dropout != 0.0:
+            elif attention_dropout != 0.0:
                logger.debug(
                    "Disabling FusedAttention as it only supports sliding window attention "
-                    "with (left, 0) and no dropout"
+                    "without dropout"
                )
                use_fused_attention = False
            elif max_seqlen_q > max_seqlen_kv:
...
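The comment table above notes that the unfused backend converts ``window_size`` into an 'arbitrary' mask; a minimal sketch of such a conversion (not the library's actual helper) might look like:

import torch

def sliding_window_mask(seqlen_q, seqlen_kv, window_size, bottom_right_diagonal=True):
    # Boolean mask (True = masked out) for a (left, right) sliding window, aligned to
    # the bottom right or top left diagonal of the softmax matrix.
    q_idx = torch.arange(seqlen_q).unsqueeze(1)
    kv_idx = torch.arange(seqlen_kv).unsqueeze(0)
    offset = seqlen_kv - seqlen_q if bottom_right_diagonal else 0
    left, right = window_size
    lo = q_idx + offset - (left if left >= 0 else seqlen_kv)
    hi = q_idx + offset + (right if right >= 0 else seqlen_kv)
    return (kv_idx < lo) | (kv_idx > hi)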
...
@@ -914,6 +911,12 @@ def get_attention_backend(
                    "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
                )
                use_flash_attention_2 = False
+            elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
+                logger.debug(
+                    "Disabling FlashAttention as it only supports sliding window with bottom right"
+                    " diagonal alignment for cross-attention"
+                )
+                use_flash_attention = False
    # Filter: Attention bias
    # backend | bias types | ALiBi diagonal alignment
...
...
@@ -935,6 +938,12 @@ def get_attention_backend(
        elif not FlashAttentionUtils.v2_4_plus:
            logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
            use_flash_attention_2 = False
+        elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
+            logger.debug(
+                "Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
+                " alignment for cross-attention"
+            )
+            use_flash_attention = False
    if (
        core_attention_bias_type not in ["no_bias", "alibi"]
...
...
@@ -952,13 +961,12 @@ def get_attention_backend(
    if (
        use_fused_attention
        and core_attention_bias_type == "alibi"
-        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
+        and (alibi_slopes_shape is not None)
    ):
        fu_core_attention_bias_type = "post_scale_bias"
        fu_core_attention_bias_requires_grad = False
-        if alibi_slopes_shape is None:
-            fu_core_attention_bias_shape = "1hss"
-        elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
+        if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
            fu_core_attention_bias_shape = "1hss"
        elif (
            len(alibi_slopes_shape) == 2
...
...
@@ -1008,6 +1016,7 @@ def get_attention_backend(
            window_size[1],
            return_max_logit,
            cuda_graph,
            deterministic,
        )
        if fused_attention_backend == FusedAttnBackend["No_Backend"]:
            logger.debug("Disabling FusedAttention as no backend supports the provided input")
...
...
@@ -1062,6 +1071,15 @@ def get_attention_backend(
            )
            use_flash_attention_2 = False
    if use_fused_attention and deterministic:
+        if softmax_type != "vanilla":
+            logger.debug(
+                "Disabling FusedAttention for determinism reasons with softmax_type = %s. "
+                "Sink attention (off-by-one and learnable softmax) requires "
+                "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1",
+                softmax_type,
+            )
+            use_fused_attention = False
+            fused_attention_backend = None
        if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
            logger.debug("Disabling FusedAttention for determinism reasons with FP8")
            use_fused_attention = False
...
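As the new log message above indicates, sink attention (``softmax_type != "vanilla"``) stays on the fused backend only when non-deterministic algorithms are allowed:

import os
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"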
...
@@ -1078,10 +1096,6 @@ def get_attention_backend(
            logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
            use_fused_attention = False
            fused_attention_backend = None
-        if is_training and device_compute_capability >= (10, 0):
-            logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
-            use_fused_attention = False
-            fused_attention_backend = None
    # use_flash_attention may have been set above
    use_flash_attention_2 = use_flash_attention and use_flash_attention_2
...
...
transformer_engine/pytorch/attention/multi_head_attention.py
View file @
9df0c4a3
...
...
@@ -8,7 +8,6 @@ import collections
from typing import Callable, List, Optional, Tuple, Union
import torch
-from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
...
...
@@ -32,6 +31,7 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils
from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled
...
...
@@ -93,6 +93,11 @@ class MultiheadAttention(torch.nn.Module):
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in :meth:`forward` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
num_gqa_groups : int, default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
...
...
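For illustration only, a hypothetical MultiheadAttention configuration using the new constructor argument added below:

from transformer_engine.pytorch import MultiheadAttention

mha = MultiheadAttention(
    hidden_size=1024,
    num_attention_heads=16,
    attn_mask_type="causal_bottom_right",
    window_size=(256, 0),
    bottom_right_diagonal=True,  # align window/ALiBi diagonal to the bottom right
)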
@@ -248,6 +253,7 @@ class MultiheadAttention(torch.nn.Module):
        layer_number: Optional[int] = None,
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
+        bottom_right_diagonal: Optional[bool] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        num_gqa_groups: Optional[int] = None,
...
...
@@ -286,6 +292,7 @@ class MultiheadAttention(torch.nn.Module):
        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        self.window_size = window_size
+        self.bottom_right_diagonal = bottom_right_diagonal
        self.layer_number = 1 if layer_number is None else layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
...
...
@@ -335,6 +342,7 @@ class MultiheadAttention(torch.nn.Module):
        self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
        self.name = name
+        TransformerEngineBaseModule._validate_name(self)
        common_gemm_kwargs = {
            "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
...
...
@@ -621,6 +629,7 @@ class MultiheadAttention(torch.nn.Module):
        encoder_output: Optional[torch.Tensor] = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
+        bottom_right_diagonal: Optional[bool] = None,
        is_first_microbatch: Optional[bool] = None,
        checkpoint_core_attention: bool = False,
        inference_params: Optional[InferenceParams] = None,
...
...
@@ -667,6 +676,11 @@ class MultiheadAttention(torch.nn.Module):
aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
sliding window size for local attention.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using
``layer_type="decoder"``.
...
...
@@ -731,6 +745,17 @@ class MultiheadAttention(torch.nn.Module):
        if window_size is None:
            window_size = self.window_size
        window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
+        if bottom_right_diagonal is None:
+            bottom_right_diagonal = self.bottom_right_diagonal
+        if attn_mask_type in {"causal", "padding_causal"}:
+            bottom_right_diagonal = False
+        if bottom_right_diagonal is None or attn_mask_type in {
+            "causal_bottom_right",
+            "padding_causal_bottom_right",
+        }:
+            bottom_right_diagonal = True
        if "padding" in attn_mask_type and attention_mask is not None:
            for mask in attention_mask:
                assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"
...
...
@@ -739,9 +764,6 @@ class MultiheadAttention(torch.nn.Module):
            core_attention_bias_type in AttnBiasTypes
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
-        if TEDebugState.debug_enabled:
-            TransformerEngineBaseModule._validate_name(self)
        # =================================================
        # Pre-allocate memory for key-value cache for inference
        # =================================================
...
...
@@ -1004,6 +1026,7 @@ class MultiheadAttention(torch.nn.Module):
            attention_mask=attention_mask,
            attn_mask_type=attn_mask_type,
            window_size=window_size,
+            bottom_right_diagonal=bottom_right_diagonal,
            checkpoint_core_attention=checkpoint_core_attention,
            core_attention_bias_type=core_attention_bias_type,
            core_attention_bias=core_attention_bias,
...
...
transformer_engine/pytorch/cpp_extensions/fused_attn.py
View file @
9df0c4a3
...
...
@@ -137,6 +137,7 @@ def fused_attn_fwd(
    attn_mask_type: str = "padding",
    softmax_type: str = "vanilla",
    window_size: Tuple[int, int] = (-1, -1),
+    bottom_right_diagonal: bool = None,
    rng_gen: torch.Generator = None,
    softmax_offset: torch.Tensor = None,
    return_max_logit: bool = False,
...
...
@@ -212,6 +213,9 @@ def fused_attn_fwd(
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
bottom_right_diagonal: bool, default = None
whether to align sliding window and ALiBi diagonal to the top left (False) or
bottom right (True) corner of the softmax matrix.
rng_gen : torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
...
...
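A small worked example (not part of this commit) of the inclusive kv range quoted in the ``window_size`` description above:

def sliding_window_kv_range(i, seqlen_q, seqlen_k, window_size):
    # Inclusive kv-index range attended by query position i, clamped to valid indices.
    left, right = window_size
    lo = 0 if left < 0 else i + seqlen_k - seqlen_q - left
    hi = seqlen_k - 1 if right < 0 else i + seqlen_k - seqlen_q + right
    return max(lo, 0), min(hi, seqlen_k - 1)

# e.g. seqlen_q=4, seqlen_k=8, window_size=(2, 0): query 3 attends kv indices 5..7
print(sliding_window_kv_range(3, 4, 8, (2, 0)))  # (5, 7)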
@@ -255,6 +259,12 @@ def fused_attn_fwd(
    max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None
    """
+    if bottom_right_diagonal is None:
+        bottom_right_diagonal = attn_mask_type in {
+            "causal_bottom_right",
+            "padding_causal_bottom_right",
+        }
    if attn_scale is None:
        d = q.size(-1)
        attn_scale = 1.0 / math.sqrt(d)
...
...
@@ -306,6 +316,7 @@ def fused_attn_fwd(
        AttnMaskType[attn_mask_type],
        SoftmaxType[softmax_type],
        window_size,
+        bottom_right_diagonal,
        cu_seqlens_q,
        cu_seqlens_kv,
        q,
...
...
@@ -370,6 +381,7 @@ def fused_attn_bwd(
    attn_mask_type: str = "padding",
    softmax_type: str = "vanilla",
    window_size: Tuple[int, int] = (-1, -1),
+    bottom_right_diagonal: bool = None,
    deterministic: bool = False,
    cuda_graph: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
...
...
@@ -442,6 +454,9 @@ def fused_attn_bwd(
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
bottom_right_diagonal: bool, default = None
whether to align sliding window and ALiBi diagonal to the top left (False) or
bottom right (True) corner of the softmax matrix.
deterministic : bool, default = False
whether to execute the backward pass with deterministic behaviours.
cuda_graph : bool, default = False
...
...
@@ -462,6 +477,12 @@ def fused_attn_bwd(
        gradient tensor of softmax offset of shape [1, h_q, 1, 1].
        See softmax_type in DotProductAttention for details.
    """
+    if bottom_right_diagonal is None:
+        bottom_right_diagonal = attn_mask_type in {
+            "causal_bottom_right",
+            "padding_causal_bottom_right",
+        }
    if attn_scale is None:
        d = q.size(-1)
        attn_scale = 1.0 / math.sqrt(d)
...
...
@@ -500,6 +521,7 @@ def fused_attn_bwd(
        AttnMaskType[attn_mask_type],
        SoftmaxType[softmax_type],
        window_size,
+        bottom_right_diagonal,
        deterministic,
        cu_seqlens_q,
        cu_seqlens_kv,
...
...