OpenDAS / TransformerEngine · Commit 970620a5

merge nv_release_v2.10 to release_v2.10

Authored Dec 27, 2025 by wenjh
Signed-off-by: wenjh <wenjh@sugon.com>
Parents: c1a1c04e, 769ed778

Changes: 135 files in total; showing 20 changed files with 648 additions and 240 deletions (+648 −240).
tests/pytorch/distributed/test_comm_gemm_overlap.py  +6 −0
tests/pytorch/distributed/test_sanity.py  +73 −13
tests/pytorch/test_fusible_ops.py  +5 −5
tests/pytorch/test_numerics.py  +0 −7
tests/pytorch/test_sanity.py  +1 −3
tests/pytorch/utils.py  +10 −0
transformer_engine/common/fused_attn/kv_cache.cu  +1 −1
transformer_engine/common/include/transformer_engine/fused_attn.h  +1 −1
transformer_engine/common/recipe/__init__.py  +20 −18
transformer_engine/common/transformer_engine.cpp  +12 −7
transformer_engine/jax/attention.py  +75 −18
transformer_engine/jax/cpp_extensions/activation.py  +1 −1
transformer_engine/jax/cpp_extensions/attention.py  +236 −56
transformer_engine/jax/cpp_extensions/gemm.py  +1 −1
transformer_engine/jax/cpp_extensions/misc.py  +1 −1
transformer_engine/jax/cpp_extensions/normalization.py  +1 −1
transformer_engine/jax/cpp_extensions/quantization.py  +19 −17
transformer_engine/jax/cpp_extensions/softmax.py  +65 −9
transformer_engine/jax/csrc/extensions.h  +12 −12
transformer_engine/jax/csrc/extensions/attention.cpp  +108 −69
tests/pytorch/distributed/test_comm_gemm_overlap.py (+6 −0)

```diff
@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
     os.environ["PYTORCH_JIT"] = "0"
     os.environ["NVTE_TORCH_COMPILE"] = "0"
     os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
+    if te.get_device_compute_capability() <= (8, 0):
+        # We've experienced numerical discrepancies in Flash Attention
+        # backward when running with Userbuffers on A100s. This does
+        # not show up in more recent GPUs.
+        os.environ["NVTE_FLASH_ATTN"] = "0"

     result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)

     os.unsetenv("PYTORCH_JIT")
     os.unsetenv("NVTE_TORCH_COMPILE")
     os.unsetenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO")
+    os.unsetenv("NVTE_FLASH_ATTN")

     if (result.returncode != 0
```
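A note on the pattern above: `os.unsetenv()` removes a variable from the process environment but does not update the `os.environ` mapping, so keys set via `os.environ[...]` remain visible to later Python code in the same process; `del os.environ[key]` updates the mapping and calls `unsetenv()` as a side effect. A minimal sketch of a scoped alternative (`scoped_env` is a hypothetical helper, not part of this change):

```python
import os
from contextlib import contextmanager

@contextmanager
def scoped_env(**overrides):
    """Temporarily set environment variables, restoring prior state on exit."""
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update({key: str(value) for key, value in overrides.items()})
    try:
        yield
    finally:
        for key, value in saved.items():
            if value is None:
                del os.environ[key]  # removes the key and calls unsetenv()
            else:
                os.environ[key] = value
```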
tests/pytorch/distributed/test_sanity.py (+73 −13)

```diff
@@ -7,7 +7,7 @@ import sys
 import pytest
 import torch
 import transformer_engine
-from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear
+from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear, GroupedLinear

 _current_file = pathlib.Path(__file__).resolve()
 sys.path.append(str(_current_file.parent.parent))
```
```diff
@@ -19,7 +19,9 @@ model_configs = {
 @pytest.mark.parametrize("model", ["small"])
-@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"])
+@pytest.mark.parametrize(
+    "module", ["TransformerLayer", "DotProductAttention", "Linear", "GroupedLinear"]
+)
 def test_current_device(model, module):
     """Test cases where current device is different from tensor device"""
```
```diff
@@ -42,7 +44,29 @@ def test_current_device(model, module):
             self_attn_mask_type="padding",
             device=f"cuda:{tensor_device}",
         )
+        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
+        seqlens_q = torch.randint(
+            1,
+            config.max_seqlen_q,
+            [config.batch_size],
+            dtype=torch.int32,
+            device=f"cuda:{tensor_device}",
+        )
+        cu_seqlens_q = torch.zeros(
+            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
+        )
+        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+        seqlens_kv = torch.randint(
+            1,
+            config.max_seqlen_kv,
+            [config.batch_size],
+            dtype=torch.int32,
+            device=f"cuda:{tensor_device}",
+        )
+        cu_seqlens_kv = torch.zeros(
+            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
+        )
+        cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
+        num_tokens = cu_seqlens_q[-1]
         args = [
             torch.randn(
                 (num_tokens, config.hidden_size),
```
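For readers unfamiliar with the packed (`thd`) format used here, the cumulative-sequence-length construction is easiest to see with concrete numbers; a small standalone illustration (values chosen arbitrarily):

```python
import torch

seqlens_q = torch.tensor([3, 1, 4], dtype=torch.int32)  # per-sequence token counts
cu_seqlens_q = torch.zeros(4, dtype=torch.int32)        # batch_size + 1 entries
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
print(cu_seqlens_q)  # tensor([0, 3, 4, 8]) -> start offset of each sequence
# cu_seqlens_q[-1] == 8 is the total token count, used to size the packed input.
```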
```diff
@@ -51,37 +75,55 @@ def test_current_device(model, module):
                 requires_grad=True,
             )
         ]
-        cu_seqlens_q, cu_seqlens_kv = [
-            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device)
-            for _ in range(2)
-        ]
         kwargs["cu_seqlens_q"] = cu_seqlens_q
         kwargs["cu_seqlens_kv"] = cu_seqlens_kv
         kwargs["max_seqlen_q"] = config.max_seqlen_q
         kwargs["max_seqlen_kv"] = config.max_seqlen_kv
-    if module == "DotProductAttention":
+    elif module == "DotProductAttention":
         model = DotProductAttention(
             config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding"
         )
+        num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item()
+        seqlens_q = torch.randint(
+            1,
+            config.max_seqlen_q,
+            [config.batch_size],
+            dtype=torch.int32,
+            device=f"cuda:{tensor_device}",
+        )
+        cu_seqlens_q = torch.zeros(
+            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
+        )
+        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+        seqlens_kv = torch.randint(
+            1,
+            config.max_seqlen_kv,
+            [config.batch_size],
+            dtype=torch.int32,
+            device=f"cuda:{tensor_device}",
+        )
+        cu_seqlens_kv = torch.zeros(
+            config.batch_size + 1, dtype=torch.int32, device=f"cuda:{tensor_device}"
+        )
+        cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
+        num_tokens = cu_seqlens_q[-1]
         args = [
             torch.randn(
                 num_tokens,
                 config.num_heads,
                 config.head_dim_qk,
                 dtype=dtype,
-                device=tensor_device,
+                device=f"cuda:{tensor_device}",
                 requires_grad=True,
             )
             for _ in range(3)
         ]
-        cu_seqlens_q, cu_seqlens_kv = [
-            torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device)
-            for _ in range(2)
-        ]
         kwargs["cu_seqlens_q"] = cu_seqlens_q
         kwargs["cu_seqlens_kv"] = cu_seqlens_kv
         kwargs["max_seqlen_q"] = config.max_seqlen_q
         kwargs["max_seqlen_kv"] = config.max_seqlen_kv
-        bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)]
+        bwd_args = [
+            torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=f"cuda:{tensor_device}")
+        ]
     elif module == "Linear":
         model = Linear(
             config.hidden_size,
```
```diff
@@ -97,6 +139,24 @@ def test_current_device(model, module):
                 requires_grad=True,
             )
         ]
+    elif module == "GroupedLinear":
+        num_gemms = 4
+        model = GroupedLinear(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            params_dtype=dtype,
+            device=f"cuda:{tensor_device}",
+        )
+        args = [
+            torch.randn(
+                (config.max_seqlen_q * config.batch_size * (num_gemms - 1), config.hidden_size),
+                dtype=dtype,
+                device=f"cuda:{tensor_device}",
+                requires_grad=True,
+            ),
+            [0] + [config.max_seqlen_q * config.batch_size] * (num_gemms - 1),  # Empty first split.
+        ]

     current_device_before = torch.cuda.current_device()
     out = model(*args, **kwargs)
```
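The second positional argument built in the GroupedLinear hunk above is the per-GEMM split list: `[0] + [config.max_seqlen_q * config.batch_size] * (num_gemms - 1)` deliberately makes the first split empty, so the test exercises the zero-token GEMM edge case. A reduced sketch of the same shape arithmetic (numbers are illustrative only):

```python
num_gemms = 4
tokens_per_gemm = 128  # stand-in for config.max_seqlen_q * config.batch_size
m_splits = [0] + [tokens_per_gemm] * (num_gemms - 1)
print(m_splits)       # [0, 128, 128, 128] -> first GEMM sees zero rows
print(sum(m_splits))  # 384 == number of rows in the stacked input tensor above
```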
tests/pytorch/test_fusible_ops.py (+5 −5)

```diff
@@ -913,15 +913,15 @@ class TestBasicOps:
             dtype=dtype,
             accumulate_into_main_grad=accumulate_into_main_grad,
         )
-        with torch.no_grad():
-            op.weight.copy_(w_test)
-        del w_test
-        op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
         forward = te_ops.Sequential(
             te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
             op,
             te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
         )
+        with torch.no_grad():
+            op.weight.copy_(w_test)
+        del w_test
+        op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
         with te.autocast(enabled=quantized_compute, recipe=recipe):
             y_test = forward(x_test)
         y_test.backward(dy_test)
```
tests/pytorch/test_numerics.py (+0 −7)

```diff
@@ -46,7 +46,6 @@ from transformer_engine.pytorch import (
 from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
-from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from utils import ModelConfig, reset_rng_states
@@ -2757,7 +2756,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
         general_gemm(
             A[i],
             B[i],
-            get_workspace(),
             dtype,
             grad=grad,
             accumulate=accumulate,
@@ -2772,7 +2770,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
         B,
         out,
         dtype,
-        get_multi_stream_cublas_workspace(),
         m_splits=m_splits,
         grad=grad,
         accumulate=accumulate,
@@ -2832,7 +2829,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
     quantized_out, *_ = general_gemm(
         weight_fp8,
         inp_fp8,
-        get_workspace(),
         outp_type,
         quantization_params=out_quantizer,
         bias=None,
@@ -2842,7 +2838,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
     out, *_ = general_gemm(
         weight_fp8,
         inp_fp8,
-        get_workspace(),
         outp_type,
         quantization_params=None,
         bias=None,
@@ -2918,7 +2913,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
         general_gemm(
             A_fp8[i],
             B_fp8[i],
-            get_workspace(),
             dtype,
             out=out_ref[i],
             accumulate=accumulate,
@@ -2928,7 +2922,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
         B_fp8,
         out,
         dtype,
-        get_multi_stream_cublas_workspace(),
         m_splits=m_splits,
         accumulate=accumulate,
     )
```
tests/pytorch/test_sanity.py (+1 −3)

```diff
@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions import general_gemm
-from transformer_engine.pytorch.module.base import get_workspace
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from utils import ModelConfig
@@ -961,7 +960,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
     inp = torch.reshape(scratchpad[offset:-offset], (N, N))
     weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
-    _ = general_gemm(A=weight, B=inp, workspace=get_workspace())
+    _ = general_gemm(A=weight, B=inp)
     torch.cuda.synchronize()
@@ -985,7 +984,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
     general_gemm(
         weight_fp8,
         inp_fp8,
-        get_workspace(),
         outp_type,
         bias=None,
         use_split_accumulator=False,
```
tests/pytorch/utils.py (+10 −0)

```diff
@@ -8,6 +8,7 @@ import logging
 import os
 from contextlib import contextmanager
 from typing import Optional, Tuple, Dict, Any, List
+from packaging.version import Version as PkgVersion

 import torch
@@ -210,6 +211,7 @@ class ModelConfig:
         max_ctx_len: int = None,
         num_layers: int = 1,
         eps: float = 1e-5,
+        num_splits=1,
     ):
         self.batch_size = batch_size
         self.max_seqlen_q = max_seqlen_q
@@ -239,6 +241,7 @@ class ModelConfig:
         self.max_ctx_len = max_ctx_len
         self.num_layers = num_layers
         self.eps = eps
+        self.num_splits = num_splits

 @contextmanager
@@ -321,6 +324,9 @@ def get_available_attention_backends(
         inference_params=inference_params,
         softmax_type=config.softmax_type,
         return_max_logit=config.return_max_logit,
+        # allow all backends to pass so they can be used for testing;
+        # check for FA3 availability later
+        num_splits=1,
     )
     (
         use_flash_attention,
@@ -330,6 +336,10 @@ def get_available_attention_backends(
         use_unfused_attention,
         available_backends,
     ) = get_attention_backend(attention_params)
+    # Check if FA3 is an available backend when num_splits != 1
+    if available_backends[0]:
+        if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"):
+            available_backends[0] = False

     # Set attention.py _attention_backends var using return value
     # from get_attention_backend()
     _attention_backends["use_flash_attention"] = use_flash_attention
```
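The FA3 gate above relies on `packaging` version semantics, where the string `"3.0.0b"` parses as the pre-release `3.0.0b0`; a quick illustration:

```python
from packaging.version import Version as PkgVersion

print(PkgVersion("3.0.0b1") > PkgVersion("3.0.0b"))  # True: beta 1 > beta 0
print(PkgVersion("3.0.0") > PkgVersion("3.0.0b"))    # True: final release > any beta
print(PkgVersion("2.9.9") > PkgVersion("3.0.0b"))    # False: an FA2 version fails the gate
```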
transformer_engine/common/fused_attn/kv_cache.cu (+1 −1)

```diff
@@ -278,7 +278,7 @@ void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, in
 /***************************************************************************************************
  * KV Cache: Copy new KV tokens to the KV cache
  * 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
- * 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
+ * 2. cu_new_lens and cu_cached_lens are of shape [b + 1]; cu_cached_lens include the added lens
  *    in current step
  * 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
  *    max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
```
transformer_engine/common/include/transformer_engine/fused_attn.h (+1 −1)

```diff
@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
  * NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
  * NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
  * NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
- * where alpha is a learnable parameter in shape [H].
+ * where alpha is a learnable parameter of shape [H].
 */
enum NVTE_Softmax_Type {
  /*! Vanilla softmax */
```
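The three variants in this header differ only in the denominator; a short JAX sketch of the formulas as written (illustrative values, not TE code):

```python
import jax.numpy as jnp

s = jnp.array([1.0, 2.0, 3.0])      # one row of attention scores
e = jnp.exp(s)

vanilla = e / e.sum()               # rows sum to exactly 1
off_by_one = e / (1.0 + e.sum())    # rows sum to < 1 (implicit zero logit)
alpha = jnp.array(0.5)              # learnable per-head scalar; shape [H] in TE
learnable = e / (jnp.exp(alpha) + e.sum())

print(vanilla.sum(), off_by_one.sum(), learnable.sum())
```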
transformer_engine/common/recipe/__init__.py (+20 −18)

```diff
@@ -50,7 +50,7 @@ class MMParams:
     Parameters
     ----------
-    use_split_accumulator : bool, default = `True`
+    use_split_accumulator : bool, default = True
                           Use FP8 fast accumulation on Hopper or Ada. For more details,
                           see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
     """
@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
                       recipe: DelayedScaling) -> Tensor
               where `Tensor` is a framework tensor type.
-    reduce_amax: bool, default = `True`
+    reduce_amax: bool, default = True
               By default, if `torch.distributed` is initialized, the `amax` value for FP8
               tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
               call). This keeps the amaxes and scaling factors synced across the given
@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
              GPU maintains local amaxes and scaling factors. To ensure results are
              numerically identical across checkpointing boundaries in this case, all
              ranks must checkpoint in order to store the local tensors.
-    fp8_dpa: bool, default = `False`
+    fp8_dpa: bool, default = False
             Whether to enable FP8 dot product attention (DPA). When the model is placed in an
             `autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
             inputs from higher precision to FP8, performs attention in FP8, and casts tensors
             back to higher precision as outputs. FP8 DPA currently is only supported in the
             `FusedAttention` backend.
-    fp8_mha: bool, default = `False`
+    fp8_mha: bool, default = False
             Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
             operations mentioned above at the DPA boundaries. Currently only standard MHA modules
             i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
     ----------
     fp4_format : {Format.E2M1}, default = Format.E2M1
                 FP4 data type.
-    disable_rht : bool, default = `False`
+    disable_rht : bool, default = False
                 If set to `True`, random Hadamard transforms are not applied to any tensor.
-    disable_stochastic_rounding : bool, default = `False`
+    disable_stochastic_rounding : bool, default = False
                 If set to `True`, stochastic rounding is disabled during quantization for all tensors.
-    disable_2d_quantization : bool, default = `False`
+    disable_2d_quantization : bool, default = False
                 If set to `True`, 1D block scaling with block size 16 is used for all tensors.
     """
@@ -494,13 +494,15 @@ class CustomRecipe(Recipe):
     qfactory : Callable
               Factory callable that returns a quantizer instance for a
              given semantic tensor role.
-              The callable is typically invoked as:
+              The callable is typically invoked as::
+
                   qfactory(
                       role: str,
                   )
+
              Where `role` is one of the following strings for e.g. te.Linear
              (stable public contract):
              - forward: "linear_input", "linear_weight", "linear_output"
              - backward: "linear_grad_output", "linear_grad_input"
     """
```
transformer_engine/common/transformer_engine.cpp (+12 −7)

```diff
@@ -736,12 +736,17 @@ int nvte_is_non_tn_fp8_gemm_supported() {
 #if USE_ROCM
   return true;
 #else
-  int deviceComputeCapability =
-      transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
-  // Note: this is temporary restriction and should be lifted in the future.
-  // (remove the note once it's done.)
-  return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
-         deviceComputeCapability >= 130;
+  int num_devices = transformer_engine::cuda::num_devices();
+  static std::vector<int> cache(num_devices, -1);
+  static std::vector<std::once_flag> flags(num_devices);
+  int device_id = transformer_engine::cuda::current_device();
+  std::call_once(flags[device_id], [&]() {
+    int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id);
+    // Note: this is temporary restriction and should be lifted in the future.
+    // (remove the note once it's done.)
+    cache[device_id] = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
+                       deviceComputeCapability >= 130;
+  });
+  return cache[device_id];
 #endif
 }
```
transformer_engine/jax/attention.py (+75 −18)

```diff
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
 from transformer_engine_jax import NVTE_QKV_Layout
 from transformer_engine_jax import NVTE_QKV_Format
 from transformer_engine_jax import nvte_get_qkv_format
+from transformer_engine_jax import NVTE_Softmax_Type

 from . import cpp_extensions as tex
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
 ]


+class AttnSoftmaxType(Enum):
+    """
+    VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
+    OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
+    LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
+    where alpha is a learnable parameter of shape [H].
+    """
+
+    VANILLA_SOFTMAX = NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX
+    OFF_BY_ONE_SOFTMAX = NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX
+    LEARNABLE_SOFTMAX = NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX
+
+    @classmethod
+    def from_str(cls, softmax_type: str) -> "AttnSoftmaxType":
+        """Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
+        softmax_type_map = {
+            "vanilla": cls.VANILLA_SOFTMAX,
+            "off_by_one": cls.OFF_BY_ONE_SOFTMAX,
+            "learnable": cls.LEARNABLE_SOFTMAX,
+        }
+        result = softmax_type_map.get(softmax_type)
+        if result is None:
+            raise ValueError(
+                f"Unknown softmax_type: {softmax_type}. "
+                "Valid options: 'vanilla', 'off_by_one', 'learnable'"
+            )
+        return result
+
+
 class QKVFormat(Enum):
     """
     SBHD: q,k,v memory layout with [s, b, ..., h, d]
```
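A brief usage sketch of the new enum, based on the `from_str` mapping above (the import path follows this file's location and is my assumption):

```python
from transformer_engine.jax.attention import AttnSoftmaxType

assert AttnSoftmaxType.from_str("off_by_one") is AttnSoftmaxType.OFF_BY_ONE_SOFTMAX
AttnSoftmaxType.from_str("sink")  # raises ValueError listing the valid options
```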
```diff
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
     qkv_layout,
     attn_bias_type,
     attn_mask_type,
+    softmax_type,
     dropout_probability,
     q_num_heads,
     kv_num_heads,
@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
     """
     To check whether the fused attention kernel is supported
     """
+    window_size_tuple = (-1, -1) if window_size is None else window_size

     def make_helper(attn_mask_type):
         return tex.FusedAttnHelper(
@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
             qkv_layout,
             attn_bias_type,
             attn_mask_type,
+            softmax_type,
             dropout_probability,
             q_num_heads,
             kv_num_heads,
@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
             kv_max_seqlen,
             head_dim_qk,
             head_dim_v,
-            (-1, -1) if window_size is None else window_size,
+            window_size_tuple,
         )

     return make_helper(attn_mask_type).is_fused_attn_kernel_available()
```
```diff
@@ -497,6 +530,11 @@ def _segment_ids_pos_to_seqlens_offsets(
     #
     # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
     # examine only O(Q+KV) elements.
+    # For seqlens and seqoffsets calculations, the intermediate (temp) attn_mask creation
+    # using the segment ids and pos along with mask type (causal or brcm) is sufficient.
+    # It does not need to involve SWA for this mask's creation.
+    # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
     if (attn_mask_type.is_causal() and window_size is None) or (
         window_size == (-1, -1) and not attn_mask_type.is_bottom_right()
@@ -558,21 +596,6 @@ def _segment_ids_pos_to_seqlens_offsets(
     )
     attn_mask = jnp.logical_and(segment_mask, causal_mask)

-    # TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
-    swa_mask = (
-        make_swa_mask(
-            segment_pos_q,
-            segment_pos_kv,
-            window_size,
-            dtype=jnp.bool,
-            segment_ids_q=segment_ids_q,
-            segment_ids_kv=segment_ids_kv,
-        )
-        if attn_mask_type.is_bottom_right()
-        else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
-    )
-    attn_mask = jnp.logical_and(attn_mask, swa_mask)
     attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)

     q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
         attn_mask_with_id, max_segments_per_seq
```
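The fast path referenced in these comments derives seqlens and offsets from segment ids directly rather than materializing the Q×KV mask. A rough standalone sketch of the idea, heavily simplified and not the library's implementation:

```python
import jax.numpy as jnp

def segment_seqlens(segment_ids, max_segments):
    """Per-segment token counts and start offsets in O(len(segment_ids))."""
    # Convention assumed here: segment id 0 marks padding; real segments are 1..max_segments.
    seqlens = jnp.stack(
        [jnp.sum(segment_ids == s) for s in range(1, max_segments + 1)]
    )
    offsets = jnp.concatenate(
        [jnp.zeros(1, jnp.int32), jnp.cumsum(seqlens)[:-1].astype(jnp.int32)]
    )
    return seqlens, offsets

ids = jnp.array([1, 1, 1, 2, 2, 0, 0, 0])    # two segments, then padding
print(segment_seqlens(ids, max_segments=2))  # seqlens [3, 2], offsets [0, 3]
```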
```diff
@@ -786,6 +809,7 @@ def _legacy_fused_attn(
     attn_bias_type: AttnBiasType,
     attn_mask_type: AttnMaskType,
     qkv_layout: QKVLayout,
+    softmax_type: AttnSoftmaxType,
     scaling_factor: float,
     dropout_probability: float,
     is_training: bool,
@@ -793,6 +817,7 @@ def _legacy_fused_attn(
     context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
+    softmax_offset: Optional[jnp.ndarray] = None,
 ):
     """
     Perform non-THD (non-packed) cuDNN fused attention.
@@ -815,6 +840,7 @@ def _legacy_fused_attn(
         seed (Optional[jnp.ndarray]): Optional random seed for dropout.
         attn_bias_type (AttnBiasType): Type of attention bias.
         attn_mask_type (AttnMaskType): Type of attention mask.
+        softmax_type (AttnSoftmaxType): Type of attention softmax.
         qkv_layout (QKVLayout): Layout of the QKV tensors.
         scaling_factor (float): Scaling factor for the attention scores.
         dropout_probability (float): Dropout probability to apply during attention.
@@ -863,10 +889,12 @@ def _legacy_fused_attn(
     output = _fused_attn(
         qkv,
         bias,
+        softmax_offset,
         SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
         seed,
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
+        softmax_type=softmax_type,
         qkv_layout=qkv_layout,
         scaling_factor=scaling_factor,
         dropout_probability=dropout_probability,
```
```diff
@@ -900,6 +928,7 @@ def fused_attn_thd(
     context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
+    softmax_offset: Optional[jnp.ndarray] = None,
 ):
     """
     Deprecated THD fused attn, please use fused_attn with SequenceDescriptor
@@ -937,6 +966,7 @@ def fused_attn_thd(
     output = _fused_attn(
         qkv,
         bias,
+        softmax_offset,
         SequenceDescriptor.from_seqlens_and_offsets(
             (q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
         ),
@@ -945,6 +975,7 @@ def fused_attn_thd(
         attn_mask_type=attn_mask_type,
         qkv_layout=qkv_layout,
         scaling_factor=scaling_factor,
+        softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
         dropout_probability=dropout_probability,
         is_training=is_training,
         max_segments_per_seq=max_segments_per_seq,
```
```diff
@@ -957,15 +988,17 @@
     return output


-@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
+@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
 def _fused_attn(
     qkv: Tuple[jnp.ndarray, ...],
     bias: Optional[jnp.ndarray],
+    softmax_offset: Optional[jnp.ndarray],
     sequence_descriptor: SequenceDescriptor,
     seed: Optional[jnp.ndarray],
     attn_bias_type: AttnBiasType,
     attn_mask_type: AttnMaskType,
     qkv_layout: QKVLayout,
+    softmax_type: AttnSoftmaxType,
     scaling_factor: float,
     dropout_probability: float,
     is_training: bool,
@@ -979,11 +1012,13 @@ def _fused_attn(
     output, _ = _fused_attn_fwd_rule(
         qkv,
         bias,
+        softmax_offset,
         sequence_descriptor,
         seed,
         attn_bias_type,
         attn_mask_type,
         qkv_layout,
+        softmax_type,
         scaling_factor,
         dropout_probability,
         is_training,
@@ -1000,11 +1035,13 @@
 def _fused_attn_fwd_rule(
     qkv,
     bias,
+    softmax_offset,
     sequence_descriptor,
     seed,
     attn_bias_type,
     attn_mask_type,
     qkv_layout,
+    softmax_type,
     scaling_factor,
     dropout_probability,
     is_training,
@@ -1018,10 +1055,12 @@
     output, softmax_aux, rng_state = tex.fused_attn_fwd(
         qkv,
         bias,
+        softmax_offset,
         sequence_descriptor,
         seed,
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
+        softmax_type=softmax_type,
         qkv_layout=qkv_layout,
         scaling_factor=scaling_factor,
         dropout_probability=dropout_probability,
@@ -1041,6 +1080,7 @@
         sequence_descriptor,
         softmax_aux,
         rng_state,
+        softmax_offset,
         output,
     )
@@ -1049,6 +1089,7 @@ def _fused_attn_bwd_rule(
     attn_bias_type,
     attn_mask_type,
     qkv_layout,
+    softmax_type,
     scaling_factor,
     dropout_probability,
     is_training,
@@ -1068,11 +1109,13 @@
         sequence_descriptor,
         softmax_aux,
         rng_state,
+        softmax_offset,
         output,
     ) = ctx
-    grad_qkv, grad_bias = tex.fused_attn_bwd(
+    grad_qkv, grad_bias, grad_softmax_offset = tex.fused_attn_bwd(
         qkv,
         bias,
+        softmax_offset,
         softmax_aux,
         rng_state,
         output,
@@ -1080,6 +1123,7 @@
         sequence_descriptor,
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
+        softmax_type=softmax_type,
         qkv_layout=qkv_layout,
         scaling_factor=scaling_factor,
         dropout_probability=dropout_probability,
@@ -1092,9 +1136,12 @@
     )
     if attn_bias_type == AttnBiasType.NO_BIAS:
         grad_bias = None
+    if softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX:
+        grad_softmax_offset = None
     return (
         grad_qkv,
         grad_bias,
+        grad_softmax_offset,
         None,
         None,
     )
```
```diff
@@ -1111,6 +1158,7 @@ def fused_attn(
     attn_bias_type: AttnBiasType,
     attn_mask_type: AttnMaskType,
     qkv_layout: QKVLayout,
+    softmax_type: AttnSoftmaxType,
     scaling_factor: float,
     dropout_probability: float,
     is_training: bool,
@@ -1120,6 +1168,7 @@ def fused_attn(
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
     context_checkpoint_name: str = "context",
+    softmax_offset: Optional[jnp.ndarray] = None,
 ):
     """
     Perform cuDNN fused attention.
@@ -1139,6 +1188,7 @@ def fused_attn(
         seed (Optional[jnp.ndarray]): Optional random seed for dropout.
         attn_bias_type (AttnBiasType): Type of attention bias.
         attn_mask_type (AttnMaskType): Type of attention mask.
+        softmax_type (AttnSoftmaxType): Type of attention softmax.
         qkv_layout (QKVLayout): Layout of the QKV tensors.
         scaling_factor (float): Scaling factor for the attention scores.
         dropout_probability (float): Dropout probability to apply during attention.
@@ -1153,6 +1203,9 @@ def fused_attn(
         Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
         context_parallel_axis (str): The name of the context parallel axis.
         context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
+        softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
+            [1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
+            If provided, this parameter will receive gradients during backpropagation.

     Returns:
         (jnp.ndarray): The output tensor from the fused attention.
@@ -1200,6 +1253,7 @@ def fused_attn(
             seed,
             attn_bias_type=attn_bias_type,
             attn_mask_type=attn_mask_type,
+            softmax_type=softmax_type,
             qkv_layout=qkv_layout,
             scaling_factor=scaling_factor,
             dropout_probability=dropout_probability,
@@ -1208,15 +1262,18 @@ def fused_attn(
             context_parallel_strategy=context_parallel_strategy,
             context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
             context_parallel_axis=context_parallel_axis,
+            softmax_offset=softmax_offset,
         )

     output = _fused_attn(
         qkv,
         bias,
+        softmax_offset,
         sequence_descriptor,
         seed,
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
         qkv_layout=qkv_layout,
+        softmax_type=softmax_type,
         scaling_factor=scaling_factor,
         dropout_probability=dropout_probability,
         is_training=is_training,
```
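Putting the new arguments together, a hedged usage sketch based solely on the signature in this diff; the import paths, enum values, and `SequenceDescriptor` construction are my assumptions, and shapes are illustrative:

```python
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    AttnBiasType, AttnMaskType, AttnSoftmaxType, QKVLayout,
    SequenceDescriptor, fused_attn,
)

b, s, h, d = 2, 128, 16, 64
q = jnp.zeros((b, s, h, d), jnp.bfloat16)
k = jnp.zeros((b, s, h, d), jnp.bfloat16)
v = jnp.zeros((b, s, h, d), jnp.bfloat16)
softmax_offset = jnp.zeros((1, h, 1, 1), jnp.float32)  # one learnable scalar per head

out = fused_attn(
    (q, k, v),
    None,  # bias
    SequenceDescriptor.from_seqlens((jnp.full((b,), s, jnp.int32),) * 2),
    None,  # seed
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.PADDING_MASK,
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    softmax_type=AttnSoftmaxType.LEARNABLE_SOFTMAX,
    scaling_factor=d ** -0.5,
    dropout_probability=0.0,
    is_training=True,
    softmax_offset=softmax_offset,  # receives gradients through the custom VJP
)
```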
transformer_engine/jax/cpp_extensions/activation.py (+1 −1)

```diff
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
 from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
 from ..quantize import (
     Quantizer,
-    QuantizeLayout,
     DelayedScaleQuantizer,
     ScalingMode,
+    QuantizeLayout,
 )
```
transformer_engine/jax/cpp_extensions/attention.py (+236 −56)

(Diff collapsed in this view.)
transformer_engine/jax/cpp_extensions/gemm.py (+1 −1)

```diff
@@ -39,12 +39,12 @@ from ..quantize import (
     Quantizer,
     GroupedQuantizer,
     QuantizerSet,
-    QuantizeLayout,
     noop_quantizer_set,
     is_fp8_gemm_with_all_layouts_supported,
     apply_padding_to_scale_inv,
     get_quantize_config_with_recipe,
     get_global_quantize_recipe,
+    QuantizeLayout,
 )
 from .misc import get_padded_spec, is_all_reduce_in_float32
 from ..sharding import (
```
transformer_engine/jax/cpp_extensions/misc.py (+1 −1)

```diff
@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
         transpose. Note, transpose_axis should be greater than static_axis_boundary

     examples:
-        X in shape (dim0, dim1, dim2, dim3, dim4)
+        X of shape (dim0, dim1, dim2, dim3, dim4)
         static_axis_boundary == -1, transpose_axis == 2
         Xt = (dim2, dim3, dim4, dim0, dim1)
```
transformer_engine/jax/cpp_extensions/normalization.py (+1 −1)

```diff
@@ -35,9 +35,9 @@ from ..sharding import (
 from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
 from ..quantize import (
     Quantizer,
-    QuantizeLayout,
     DelayedScaleQuantizer,
     ScalingMode,
+    QuantizeLayout,
 )
```
transformer_engine/jax/cpp_extensions/quantization.py (+19 −17)

```diff
@@ -40,11 +40,11 @@ from ..quantize import (
     GroupedScaledTensor1x,
     Quantizer,
     GroupedQuantizer,
-    QuantizeLayout,
     ScalingMode,
     compute_scale_from_amax,
     NoScaleTensor,
     get_rht_matrix,
+    QuantizeLayout,
 )
@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         x_spec = get_padded_spec(arg_infos[0])
         amax_spec = get_padded_spec(arg_infos[2])
+        sr_rng_state_spec = get_padded_spec(arg_infos[3])
         out_sharding = NamedSharding(
             mesh,
             PartitionSpec(*x_spec),
@@ -551,9 +552,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         )
         arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
+        if len(sr_rng_state_spec) > 1:
+            # sr_rng_state shape [n_devices, state_per_device]
+            sr_rng_state_spec = (*tuple(x for x in x_spec if x is not None), None)
         arg_shardings[3] = NamedSharding(
             mesh,
-            PartitionSpec(tuple(x for x in x_spec if x is not None), None),
+            PartitionSpec(*sr_rng_state_spec),
             desc="BaseDBiasQuantizePrimitive.sr_rng_state",
         )
         arg_shardings = tuple(arg_shardings)
@@ -654,9 +658,11 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
         dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
         amax = (BATCHING + prefix + "_amax",)
         scale = (BATCHING + prefix + "_scale",)
         sr_rng_state = (BATCHING + prefix + "_sr_rng_state",)
         if value_types[3].shape != [0]:
             sr_rng_state = (
-                BATCHING + prefix + "_sr_rng_state_partition_axis",
-                BATCHING + prefix + "sr_rng_state_data_axis",
+                BATCHING + prefix + "_sr_rng_state_devices",
+                prefix + "sr_rng_state_data",
             )
         post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
```
```diff
@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
     if force_1x_quantization:
         q_layout = QuantizeLayout.ROWWISE

-    sr_rng_state = None
+    sr_rng_state = jnp.empty((0,), jnp.uint32)
     if quantizer.scaling_mode.is_nvfp4_scaling:
         # Only NVFP4 scaling modes support stochastic rounding
         if quantizer.stochastic_rounding_rng_state is not None:
@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
         x.data,
         scale,
         amax,
-        (
-            sr_rng_state
-            if sr_rng_state is not None
-            else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
-        ),
+        sr_rng_state,
         post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
         rht_matrix,
         out_dtype=quantizer.q_dtype,
@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
         scale_dtype=quantizer.get_scale_dtype(),
         is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
         is_outer=True,
-        stochastic_rounding=sr_rng_state is not None,
+        stochastic_rounding=sr_rng_state.size != 0,
         use_rht=use_rht,
     )
     # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
```
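The switch from a `None` sentinel to a zero-size array is a common JAX pattern: the primitive always receives an array operand with a stable pytree structure, and the `size != 0` test is resolved statically at trace time. A minimal sketch of the idea under those assumptions (`maybe_stochastic_round` is a hypothetical stand-in, not TE code):

```python
import jax
import jax.numpy as jnp

@jax.jit
def maybe_stochastic_round(x, sr_rng_state):
    # Shapes are static under jit, so `sr_rng_state.size` is a plain Python
    # int at trace time and this branch folds away in the compiled function.
    if sr_rng_state.size != 0:
        return x + 1.0  # stand-in for the stochastic-rounding path
    return x

x = jnp.ones((4,))
print(maybe_stochastic_round(x, jnp.empty((0,), jnp.uint32)))      # deterministic path
print(maybe_stochastic_round(x, jnp.arange(2, dtype=jnp.uint32)))  # SR path
```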
transformer_engine/jax/cpp_extensions/softmax.py (+65 −9)

```diff
@@ -11,10 +11,11 @@ import jax
 import jax.numpy as jnp
 from jax import dtypes, ffi
 from jax.sharding import PartitionSpec, NamedSharding

+from .attention import AttnSoftmaxType
 from .base import BasePrimitive, register_primitive
 from .misc import get_padded_spec, check_valid_batch_dims
-from ..softmax import SoftmaxType
+from ..softmax import SoftmaxFusionType

 __all__ = [
@@ -32,7 +33,8 @@ __all__ = [
 def is_softmax_kernel_available(
-    softmax_type: SoftmaxType,
+    softmax_fusion_type: SoftmaxFusionType,
+    softmax_type: AttnSoftmaxType,
     batch: int,
     heads: int,
     q_seqlen: int,
@@ -40,15 +42,18 @@ def is_softmax_kernel_available(
     dtype: jnp.dtype,
 ):
     """check softmax available"""
-    if softmax_type is SoftmaxType.SCALED:
+    if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
+        return False
+
+    if softmax_fusion_type is SoftmaxFusionType.SCALED:
         return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, dtype)
-    if softmax_type is SoftmaxType.SCALED_MASKED:
+    if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
         return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, dtype)
-    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
+    if softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
         return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen, dtype)
```
```diff
@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
 register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


-def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
+def jax_scaled_softmax(
+    logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
+):
     """
     JAX based implementation of scaled softmax
     """
+    if softmax_offset is not None:
+        return jax_general_softmax(scale_factor * logits, offset=softmax_offset)
     return jax.nn.softmax(scale_factor * logits)


-def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
+def jax_scaled_masked_softmax(
+    logits: jnp.ndarray,
+    mask: jnp.ndarray,
+    scale_factor: float,
+    softmax_offset: jnp.ndarray | float | None = None,
+):
     """
     JAX based implementation of scaled and masked softmax
     """
+    if softmax_offset is not None:
+        return jax_general_softmax(logits * scale_factor, offset=softmax_offset, where=mask != 1)
     return jax.nn.softmax(logits * scale_factor, where=mask != 1)


-def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
+def jax_scaled_upper_triang_masked_softmax(
+    logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
+):
     """
     JAX based implementation of scaled and upper triangle masked softmax
     """
     mask = 1 - jnp.tril(jnp.ones_like(logits))
-    return jax_scaled_masked_softmax(logits, mask, scale_factor)
+    return jax_scaled_masked_softmax(logits, mask, scale_factor, softmax_offset)
+
+
+def jax_general_softmax(
+    x: jnp.ndarray,
+    axis: int = -1,
+    where: jnp.ndarray | None = None,
+    initial: jnp.ndarray = -jnp.inf,
+    offset: jnp.ndarray | float | None = None,
+) -> jnp.ndarray:
+    """
+    JAX based implementation of general softmax with optional masking and offset.
+    """
+    # Compute max of x
+    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
+    if offset is not None:
+        # Cast offset to x.dtype to prevent type promotion
+        if isinstance(offset, (int, float)):
+            offset = jnp.array(offset, dtype=x.dtype)
+        else:
+            offset = offset.astype(x.dtype)
+        # Include offset in max: x_max = max(x_max, offset)
+        # This is equivalent to computing max over [x..., offset]
+        x_max = jnp.maximum(x_max, offset)
+    x_safe = x if where is None else jnp.where(where, x, initial)
+    unnormalized = jnp.exp(x_safe - x_max)
+    denominator = jnp.sum(unnormalized, axis, where=where, keepdims=True)
+    if offset is not None:
+        # Add exp(offset - x_max) to denominator
+        denominator = denominator + jnp.exp(offset - x_max)
+    result = unnormalized / denominator
+    if where is not None:
+        result = jnp.where(where, result, 0)
+    return result


 def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
```
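With `offset=0.0`, the max-shifted form used by `jax_general_softmax` above reduces to the off-by-one softmax exp(x_i) / (1 + sum_j exp(x_j)). A standalone check that replicates the same arithmetic (values arbitrary):

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0]])

# Off-by-one softmax written directly from the definition ...
manual = jnp.exp(x) / (1.0 + jnp.sum(jnp.exp(x), axis=-1, keepdims=True))

# ... and via the numerically stable max-shifted form with offset = 0.0:
x_max = jnp.maximum(jnp.max(x, axis=-1, keepdims=True), 0.0)
stable = jnp.exp(x - x_max) / (
    jnp.sum(jnp.exp(x - x_max), axis=-1, keepdims=True) + jnp.exp(0.0 - x_max)
)
print(jnp.allclose(manual, stable))  # True
print(stable.sum(axis=-1))           # < 1: the residual mass acts as a sink
```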
transformer_engine/jax/csrc/extensions.h (+12 −12)

```diff
@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
 XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

-NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
-                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                            NVTE_Mask_Type mask_type, float dropout_probability,
-                                            size_t q_num_heads, size_t kv_num_heads,
-                                            size_t q_max_seqlen, size_t kv_max_seqlen,
-                                            size_t qk_head_dim, size_t v_head_dim,
-                                            int64_t window_size_left, int64_t window_size_right);
+NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
+                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                            NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+                                            float dropout_probability, size_t q_attn_heads,
+                                            size_t kv_attn_heads, size_t q_max_seqlen,
+                                            size_t kv_max_seqlen, size_t qk_head_dim,
+                                            size_t v_head_dim, int64_t window_size_left,
+                                            int64_t window_size_right);

 pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
     size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
-    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
+    DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
+    int64_t window_size_right);

 pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
     size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
     size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t qk_head_dim,
     size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
-    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
-    int64_t window_size_right);
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
+    DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
+    int64_t window_size_left, int64_t window_size_right);

 // GEMM
 XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
```
transformer_engine/jax/csrc/extensions/attention.cpp (+108 −69)

(Diff collapsed in this view.)