Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
970620a5
Commit
970620a5
authored
Dec 27, 2025
by
wenjh
Browse files
merge nv_release_v2.10 to release_v2.10
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
c1a1c04e
769ed778
Changes
135
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
648 additions
and
240 deletions
+648
-240
tests/pytorch/distributed/test_comm_gemm_overlap.py
tests/pytorch/distributed/test_comm_gemm_overlap.py
+6
-0
tests/pytorch/distributed/test_sanity.py
tests/pytorch/distributed/test_sanity.py
+73
-13
tests/pytorch/test_fusible_ops.py
tests/pytorch/test_fusible_ops.py
+5
-5
tests/pytorch/test_numerics.py
tests/pytorch/test_numerics.py
+0
-7
tests/pytorch/test_sanity.py
tests/pytorch/test_sanity.py
+1
-3
tests/pytorch/utils.py
tests/pytorch/utils.py
+10
-0
transformer_engine/common/fused_attn/kv_cache.cu
transformer_engine/common/fused_attn/kv_cache.cu
+1
-1
transformer_engine/common/include/transformer_engine/fused_attn.h
...mer_engine/common/include/transformer_engine/fused_attn.h
+1
-1
transformer_engine/common/recipe/__init__.py
transformer_engine/common/recipe/__init__.py
+20
-18
transformer_engine/common/transformer_engine.cpp
transformer_engine/common/transformer_engine.cpp
+12
-7
transformer_engine/jax/attention.py
transformer_engine/jax/attention.py
+75
-18
transformer_engine/jax/cpp_extensions/activation.py
transformer_engine/jax/cpp_extensions/activation.py
+1
-1
transformer_engine/jax/cpp_extensions/attention.py
transformer_engine/jax/cpp_extensions/attention.py
+236
-56
transformer_engine/jax/cpp_extensions/gemm.py
transformer_engine/jax/cpp_extensions/gemm.py
+1
-1
transformer_engine/jax/cpp_extensions/misc.py
transformer_engine/jax/cpp_extensions/misc.py
+1
-1
transformer_engine/jax/cpp_extensions/normalization.py
transformer_engine/jax/cpp_extensions/normalization.py
+1
-1
transformer_engine/jax/cpp_extensions/quantization.py
transformer_engine/jax/cpp_extensions/quantization.py
+19
-17
transformer_engine/jax/cpp_extensions/softmax.py
transformer_engine/jax/cpp_extensions/softmax.py
+65
-9
transformer_engine/jax/csrc/extensions.h
transformer_engine/jax/csrc/extensions.h
+12
-12
transformer_engine/jax/csrc/extensions/attention.cpp
transformer_engine/jax/csrc/extensions/attention.cpp
+108
-69
No files found.
tests/pytorch/distributed/test_comm_gemm_overlap.py
View file @
970620a5
...
...
@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
os
.
environ
[
"PYTORCH_JIT"
]
=
"0"
os
.
environ
[
"NVTE_TORCH_COMPILE"
]
=
"0"
os
.
environ
[
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
]
=
"0"
if
te
.
get_device_compute_capability
()
<=
(
8
,
0
):
# We've experienced numerical discrepancies in Flash Attention
# backward when running with Userbuffers on A100s. This does
# not show up in more recent GPUs.
os
.
environ
[
"NVTE_FLASH_ATTN"
]
=
"0"
result
=
subprocess
.
run
(
test_cmd
,
env
=
os
.
environ
,
capture_output
=
True
,
check
=
False
)
os
.
unsetenv
(
"PYTORCH_JIT"
)
os
.
unsetenv
(
"NVTE_TORCH_COMPILE"
)
os
.
unsetenv
(
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
)
os
.
unsetenv
(
"NVTE_FLASH_ATTN"
)
if
(
result
.
returncode
!=
0
...
...
tests/pytorch/distributed/test_sanity.py
View file @
970620a5
...
...
@@ -7,7 +7,7 @@ import sys
import
pytest
import
torch
import
transformer_engine
from
transformer_engine.pytorch
import
DotProductAttention
,
TransformerLayer
,
Linear
from
transformer_engine.pytorch
import
DotProductAttention
,
TransformerLayer
,
Linear
,
GroupedLinear
_current_file
=
pathlib
.
Path
(
__file__
).
resolve
()
sys
.
path
.
append
(
str
(
_current_file
.
parent
.
parent
))
...
...
@@ -19,7 +19,9 @@ model_configs = {
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"small"
])
@
pytest
.
mark
.
parametrize
(
"module"
,
[
"TransformerLayer"
,
"DotProductAttention"
,
"Linear"
])
@
pytest
.
mark
.
parametrize
(
"module"
,
[
"TransformerLayer"
,
"DotProductAttention"
,
"Linear"
,
"GroupedLinear"
]
)
def
test_current_device
(
model
,
module
):
"""Test cases where current device is different from tensor device"""
...
...
@@ -42,7 +44,29 @@ def test_current_device(model, module):
self_attn_mask_type
=
"padding"
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
num_tokens
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
,
(
1
,)).
item
()
seqlens_q
=
torch
.
randint
(
1
,
config
.
max_seqlen_q
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_q
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
seqlens_q
,
dim
=
0
)
seqlens_kv
=
torch
.
randint
(
1
,
config
.
max_seqlen_kv
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_kv
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_kv
[
1
:]
=
torch
.
cumsum
(
seqlens_kv
,
dim
=
0
)
num_tokens
=
cu_seqlens_q
[
-
1
]
args
=
[
torch
.
randn
(
(
num_tokens
,
config
.
hidden_size
),
...
...
@@ -51,37 +75,55 @@ def test_current_device(model, module):
requires_grad
=
True
,
)
]
cu_seqlens_q
,
cu_seqlens_kv
=
[
torch
.
Tensor
([
0
,
2
,
3
]).
to
(
dtype
=
torch
.
int32
,
device
=
tensor_device
)
for
_
in
range
(
2
)
]
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
if
module
==
"DotProductAttention"
:
el
if
module
==
"DotProductAttention"
:
model
=
DotProductAttention
(
config
.
num_heads
,
config
.
head_dim_qk
,
qkv_format
=
"thd"
,
attn_mask_type
=
"padding"
)
num_tokens
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
,
(
1
,)).
item
()
seqlens_q
=
torch
.
randint
(
1
,
config
.
max_seqlen_q
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_q
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
seqlens_q
,
dim
=
0
)
seqlens_kv
=
torch
.
randint
(
1
,
config
.
max_seqlen_kv
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_kv
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_kv
[
1
:]
=
torch
.
cumsum
(
seqlens_kv
,
dim
=
0
)
num_tokens
=
cu_seqlens_q
[
-
1
]
args
=
[
torch
.
randn
(
num_tokens
,
config
.
num_heads
,
config
.
head_dim_qk
,
dtype
=
dtype
,
device
=
tensor_device
,
device
=
f
"cuda:
{
tensor_device
}
"
,
requires_grad
=
True
,
)
for
_
in
range
(
3
)
]
cu_seqlens_q
,
cu_seqlens_kv
=
[
torch
.
Tensor
([
0
,
2
,
3
]).
to
(
dtype
=
torch
.
int32
,
device
=
tensor_device
)
for
_
in
range
(
2
)
]
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
bwd_args
=
[
torch
.
randn
(
num_tokens
,
config
.
hidden_size
,
dtype
=
dtype
,
device
=
tensor_device
)]
bwd_args
=
[
torch
.
randn
(
num_tokens
,
config
.
hidden_size
,
dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
)
]
elif
module
==
"Linear"
:
model
=
Linear
(
config
.
hidden_size
,
...
...
@@ -97,6 +139,24 @@ def test_current_device(model, module):
requires_grad
=
True
,
)
]
elif
module
==
"GroupedLinear"
:
num_gemms
=
4
model
=
GroupedLinear
(
num_gemms
,
config
.
hidden_size
,
4
*
config
.
hidden_size
,
params_dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
args
=
[
torch
.
randn
(
(
config
.
max_seqlen_q
*
config
.
batch_size
*
(
num_gemms
-
1
),
config
.
hidden_size
),
dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
,
requires_grad
=
True
,
),
[
0
]
+
[
config
.
max_seqlen_q
*
config
.
batch_size
]
*
(
num_gemms
-
1
),
# Empty first split.
]
current_device_before
=
torch
.
cuda
.
current_device
()
out
=
model
(
*
args
,
**
kwargs
)
...
...
tests/pytorch/test_fusible_ops.py
View file @
970620a5
...
...
@@ -913,15 +913,15 @@ class TestBasicOps:
dtype
=
dtype
,
accumulate_into_main_grad
=
accumulate_into_main_grad
,
)
with
torch
.
no_grad
():
op
.
weight
.
copy_
(
w_test
)
del
w_test
op
.
weight
.
main_grad
=
torch
.
full_like
(
op
.
weight
,
0.5
,
dtype
=
torch
.
float32
)
forward
=
te_ops
.
Sequential
(
te_ops
.
Quantize
(
forward
=
quantized_input
,
backward
=
quantized_grad_input
),
op
,
te_ops
.
Quantize
(
forward
=
quantized_output
,
backward
=
quantized_grad_output
),
)
with
torch
.
no_grad
():
op
.
weight
.
copy_
(
w_test
)
del
w_test
op
.
weight
.
main_grad
=
torch
.
full_like
(
op
.
weight
,
0.5
,
dtype
=
torch
.
float32
)
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
forward
(
x_test
)
y_test
.
backward
(
dy_test
)
...
...
tests/pytorch/test_numerics.py
View file @
970620a5
...
...
@@ -46,7 +46,6 @@ from transformer_engine.pytorch import (
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.module.base
import
get_multi_stream_cublas_workspace
,
get_workspace
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
from
utils
import
ModelConfig
,
reset_rng_states
...
...
@@ -2757,7 +2756,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
general_gemm
(
A
[
i
],
B
[
i
],
get_workspace
(),
dtype
,
grad
=
grad
,
accumulate
=
accumulate
,
...
...
@@ -2772,7 +2770,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
B
,
out
,
dtype
,
get_multi_stream_cublas_workspace
(),
m_splits
=
m_splits
,
grad
=
grad
,
accumulate
=
accumulate
,
...
...
@@ -2832,7 +2829,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
quantized_out
,
*
_
=
general_gemm
(
weight_fp8
,
inp_fp8
,
get_workspace
(),
outp_type
,
quantization_params
=
out_quantizer
,
bias
=
None
,
...
...
@@ -2842,7 +2838,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
out
,
*
_
=
general_gemm
(
weight_fp8
,
inp_fp8
,
get_workspace
(),
outp_type
,
quantization_params
=
None
,
bias
=
None
,
...
...
@@ -2918,7 +2913,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
general_gemm
(
A_fp8
[
i
],
B_fp8
[
i
],
get_workspace
(),
dtype
,
out
=
out_ref
[
i
],
accumulate
=
accumulate
,
...
...
@@ -2928,7 +2922,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
B_fp8
,
out
,
dtype
,
get_multi_stream_cublas_workspace
(),
m_splits
=
m_splits
,
accumulate
=
accumulate
,
)
...
...
tests/pytorch/test_sanity.py
View file @
970620a5
...
...
@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
from
transformer_engine.pytorch.module.base
import
get_workspace
from
transformer_engine.pytorch.tensor.utils
import
replace_raw_data
from
utils
import
ModelConfig
...
...
@@ -961,7 +960,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
inp
=
torch
.
reshape
(
scratchpad
[
offset
:
-
offset
],
(
N
,
N
))
weight
=
torch
.
reshape
(
scratchpad
[
offset
*
2
:],
(
N
,
N
))
_
=
general_gemm
(
A
=
weight
,
B
=
inp
,
workspace
=
get_workspace
()
)
_
=
general_gemm
(
A
=
weight
,
B
=
inp
)
torch
.
cuda
.
synchronize
()
...
...
@@ -985,7 +984,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
general_gemm
(
weight_fp8
,
inp_fp8
,
get_workspace
(),
outp_type
,
bias
=
None
,
use_split_accumulator
=
False
,
...
...
tests/pytorch/utils.py
View file @
970620a5
...
...
@@ -8,6 +8,7 @@ import logging
import
os
from
contextlib
import
contextmanager
from
typing
import
Optional
,
Tuple
,
Dict
,
Any
,
List
from
packaging.version
import
Version
as
PkgVersion
import
torch
...
...
@@ -210,6 +211,7 @@ class ModelConfig:
max_ctx_len
:
int
=
None
,
num_layers
:
int
=
1
,
eps
:
float
=
1e-5
,
num_splits
=
1
,
):
self
.
batch_size
=
batch_size
self
.
max_seqlen_q
=
max_seqlen_q
...
...
@@ -239,6 +241,7 @@ class ModelConfig:
self
.
max_ctx_len
=
max_ctx_len
self
.
num_layers
=
num_layers
self
.
eps
=
eps
self
.
num_splits
=
num_splits
@
contextmanager
...
...
@@ -321,6 +324,9 @@ def get_available_attention_backends(
inference_params
=
inference_params
,
softmax_type
=
config
.
softmax_type
,
return_max_logit
=
config
.
return_max_logit
,
# allow all backends to pass so they can be used for testing;
# check for FA3 availability later
num_splits
=
1
,
)
(
use_flash_attention
,
...
...
@@ -330,6 +336,10 @@ def get_available_attention_backends(
use_unfused_attention
,
available_backends
,
)
=
get_attention_backend
(
attention_params
)
# Check if FA3 is an available backend when num_splits != 1
if
available_backends
[
0
]:
if
config
.
num_splits
!=
1
and
not
flash_attention_backend
>
PkgVersion
(
"3.0.0b"
):
available_backends
[
0
]
=
False
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends
[
"use_flash_attention"
]
=
use_flash_attention
...
...
transformer_engine/common/fused_attn/kv_cache.cu
View file @
970620a5
...
...
@@ -278,7 +278,7 @@ void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, in
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are
in
shape [b + 1]; cu_cached_lens include the added lens
* 2. cu_new_lens and cu_cached_lens are
of
shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
...
...
transformer_engine/common/include/transformer_engine/fused_attn.h
View file @
970620a5
...
...
@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter
in
shape [H].
* where alpha is a learnable parameter
of
shape [H].
*/
enum
NVTE_Softmax_Type
{
/*! Vanilla softmax */
...
...
transformer_engine/common/recipe/__init__.py
View file @
970620a5
...
...
@@ -50,7 +50,7 @@ class MMParams:
Parameters
----------
use_split_accumulator : bool, default =
`
True
`
use_split_accumulator : bool, default = True
Use FP8 fast accumulation on Hopper or Ada. For more details,
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
"""
...
...
@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
recipe: DelayedScaling) -> Tensor
where `Tensor` is a framework tensor type.
reduce_amax: bool, default =
`
True
`
reduce_amax: bool, default = True
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
call). This keeps the amaxes and scaling factors synced across the given
...
...
@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
GPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default =
`
False
`
fp8_dpa: bool, default = False
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
fp8_mha: bool, default =
`
False
`
fp8_mha: bool, default = False
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
...
...
@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
----------
fp4_format : {Format.E2M1}, default = Format.E2M1
FP4 data type.
disable_rht : bool, default =
`
False
`
disable_rht : bool, default = False
If set to `True`, random Hadamard transforms are not applied to any tensor.
disable_stochastic_rounding : bool, default =
`
False
`
disable_stochastic_rounding : bool, default = False
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default =
`
False
`
disable_2d_quantization : bool, default = False
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
"""
...
...
@@ -494,13 +494,15 @@ class CustomRecipe(Recipe):
qfactory : Callable
Factory callable that returns a quantizer instance for a
given semantic tensor role.
The callable is typically invoked as:
The callable is typically invoked as::
qfactory(
role: str,
)
Where `role` is one of the following strings for e.g. te.Linear
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
"""
...
...
transformer_engine/common/transformer_engine.cpp
View file @
970620a5
...
...
@@ -736,12 +736,17 @@ int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
return
true
;
#else
int
deviceComputeCapability
=
transformer_engine
::
cuda
::
sm_arch
(
transformer_engine
::
cuda
::
current_device
());
int
num_devices
=
transformer_engine
::
cuda
::
num_devices
();
static
std
::
vector
<
int
>
cache
(
num_devices
,
-
1
);
static
std
::
vector
<
std
::
once_flag
>
flags
(
num_devices
);
int
device_id
=
transformer_engine
::
cuda
::
current_device
();
std
::
call_once
(
flags
[
device_id
],
[
&
]()
{
int
deviceComputeCapability
=
transformer_engine
::
cuda
::
sm_arch
(
device_id
);
// Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
return
(
deviceComputeCapability
>=
100
&&
deviceComputeCapability
<
120
)
||
cache
[
device_id
]
=
(
deviceComputeCapability
>=
100
&&
deviceComputeCapability
<
120
)
||
deviceComputeCapability
>=
130
;
});
return
cache
[
device_id
];
#endif
}
transformer_engine/jax/attention.py
View file @
970620a5
...
...
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
from
transformer_engine_jax
import
NVTE_QKV_Layout
from
transformer_engine_jax
import
NVTE_QKV_Format
from
transformer_engine_jax
import
nvte_get_qkv_format
from
transformer_engine_jax
import
NVTE_Softmax_Type
from
.
import
cpp_extensions
as
tex
...
...
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
]
class
AttnSoftmaxType
(
Enum
):
"""
VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [H].
"""
VANILLA_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_VANILLA_SOFTMAX
OFF_BY_ONE_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_OFF_BY_ONE_SOFTMAX
LEARNABLE_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_LEARNABLE_SOFTMAX
@
classmethod
def
from_str
(
cls
,
softmax_type
:
str
)
->
"AttnSoftmaxType"
:
"""Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
softmax_type_map
=
{
"vanilla"
:
cls
.
VANILLA_SOFTMAX
,
"off_by_one"
:
cls
.
OFF_BY_ONE_SOFTMAX
,
"learnable"
:
cls
.
LEARNABLE_SOFTMAX
,
}
result
=
softmax_type_map
.
get
(
softmax_type
)
if
result
is
None
:
raise
ValueError
(
f
"Unknown softmax_type:
{
softmax_type
}
. "
"Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
return
result
class
QKVFormat
(
Enum
):
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
...
...
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
qkv_layout
,
attn_bias_type
,
attn_mask_type
,
softmax_type
,
dropout_probability
,
q_num_heads
,
kv_num_heads
,
...
...
@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
"""
To check whether the fused attention kernel is supported
"""
window_size_tuple
=
(
-
1
,
-
1
)
if
window_size
is
None
else
window_size
def
make_helper
(
attn_mask_type
):
return
tex
.
FusedAttnHelper
(
...
...
@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
qkv_layout
,
attn_bias_type
,
attn_mask_type
,
softmax_type
,
dropout_probability
,
q_num_heads
,
kv_num_heads
,
...
...
@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
kv_max_seqlen
,
head_dim_qk
,
head_dim_v
,
(
-
1
,
-
1
)
if
window_size
is
None
else
window_size
,
window_siz
e_tupl
e
,
)
return
make_helper
(
attn_mask_type
).
is_fused_attn_kernel_available
()
...
...
@@ -497,6 +530,11 @@ def _segment_ids_pos_to_seqlens_offsets(
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
# For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation
# using the segment ids and pos along with mask type (causal or brcm) is sufficient.
# It does not need to involve SW for this mask's creation
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if
(
attn_mask_type
.
is_causal
()
and
window_size
is
None
)
or
(
window_size
==
(
-
1
,
-
1
)
and
not
attn_mask_type
.
is_bottom_right
()
...
...
@@ -558,21 +596,6 @@ def _segment_ids_pos_to_seqlens_offsets(
)
attn_mask
=
jnp
.
logical_and
(
segment_mask
,
causal_mask
)
# TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
swa_mask
=
(
make_swa_mask
(
segment_pos_q
,
segment_pos_kv
,
window_size
,
dtype
=
jnp
.
bool
,
segment_ids_q
=
segment_ids_q
,
segment_ids_kv
=
segment_ids_kv
,
)
if
attn_mask_type
.
is_bottom_right
()
else
make_swa_mask
(
segment_pos_q
,
segment_pos_kv
,
window_size
,
dtype
=
jnp
.
bool
)
)
attn_mask
=
jnp
.
logical_and
(
attn_mask
,
swa_mask
)
attn_mask_with_id
=
jnp
.
where
(
attn_mask
,
segment_mask_with_id
,
0
)
q_seqlen
,
q_offset
,
kv_seqlen
,
kv_offset
=
_mask_to_seqlens_offset
(
attn_mask_with_id
,
max_segments_per_seq
...
...
@@ -786,6 +809,7 @@ def _legacy_fused_attn(
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
...
...
@@ -793,6 +817,7 @@ def _legacy_fused_attn(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
"""
Perform non-THD (non-packed) cuDNN fused attention.
...
...
@@ -815,6 +840,7 @@ def _legacy_fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
...
...
@@ -863,10 +889,12 @@ def _legacy_fused_attn(
output
=
_fused_attn
(
qkv
,
bias
,
softmax_offset
,
SequenceDescriptor
.
from_seqlens
((
q_seq_lens
,
kv_seq_lens
)),
seed
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
...
...
@@ -900,6 +928,7 @@ def fused_attn_thd(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
"""
Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor
...
...
@@ -937,6 +966,7 @@ def fused_attn_thd(
output
=
_fused_attn
(
qkv
,
bias
,
softmax_offset
,
SequenceDescriptor
.
from_seqlens_and_offsets
(
(
q_seq_lens
,
kv_seq_lens
),
(
q_seq_offsets
,
kv_seq_offsets
)
),
...
...
@@ -945,6 +975,7 @@ def fused_attn_thd(
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
softmax_type
=
AttnSoftmaxType
.
VANILLA_SOFTMAX
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
max_segments_per_seq
=
max_segments_per_seq
,
...
...
@@ -957,15 +988,17 @@ def fused_attn_thd(
return
output
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
))
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
))
def
_fused_attn
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
sequence_descriptor
:
SequenceDescriptor
,
seed
:
Optional
[
jnp
.
ndarray
],
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
...
...
@@ -979,11 +1012,13 @@ def _fused_attn(
output
,
_
=
_fused_attn_fwd_rule
(
qkv
,
bias
,
softmax_offset
,
sequence_descriptor
,
seed
,
attn_bias_type
,
attn_mask_type
,
qkv_layout
,
softmax_type
,
scaling_factor
,
dropout_probability
,
is_training
,
...
...
@@ -1000,11 +1035,13 @@ def _fused_attn(
def
_fused_attn_fwd_rule
(
qkv
,
bias
,
softmax_offset
,
sequence_descriptor
,
seed
,
attn_bias_type
,
attn_mask_type
,
qkv_layout
,
softmax_type
,
scaling_factor
,
dropout_probability
,
is_training
,
...
...
@@ -1018,10 +1055,12 @@ def _fused_attn_fwd_rule(
output
,
softmax_aux
,
rng_state
=
tex
.
fused_attn_fwd
(
qkv
,
bias
,
softmax_offset
,
sequence_descriptor
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
...
...
@@ -1041,6 +1080,7 @@ def _fused_attn_fwd_rule(
sequence_descriptor
,
softmax_aux
,
rng_state
,
softmax_offset
,
output
,
)
...
...
@@ -1049,6 +1089,7 @@ def _fused_attn_bwd_rule(
attn_bias_type
,
attn_mask_type
,
qkv_layout
,
softmax_type
,
scaling_factor
,
dropout_probability
,
is_training
,
...
...
@@ -1068,11 +1109,13 @@ def _fused_attn_bwd_rule(
sequence_descriptor
,
softmax_aux
,
rng_state
,
softmax_offset
,
output
,
)
=
ctx
grad_qkv
,
grad_bias
=
tex
.
fused_attn_bwd
(
grad_qkv
,
grad_bias
,
grad_softmax_offset
=
tex
.
fused_attn_bwd
(
qkv
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -1080,6 +1123,7 @@ def _fused_attn_bwd_rule(
sequence_descriptor
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
...
...
@@ -1092,9 +1136,12 @@ def _fused_attn_bwd_rule(
)
if
attn_bias_type
==
AttnBiasType
.
NO_BIAS
:
grad_bias
=
None
if
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
grad_softmax_offset
=
None
return
(
grad_qkv
,
grad_bias
,
grad_softmax_offset
,
None
,
None
,
)
...
...
@@ -1111,6 +1158,7 @@ def fused_attn(
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
...
...
@@ -1120,6 +1168,7 @@ def fused_attn(
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_checkpoint_name
:
str
=
"context"
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
"""
Perform cuDNN fused attention.
...
...
@@ -1139,6 +1188,7 @@ def fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
...
...
@@ -1153,6 +1203,9 @@ def fused_attn(
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
Returns:
(jnp.ndarray): The output tensor from the fused attention.
...
...
@@ -1200,6 +1253,7 @@ def fused_attn(
seed
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
...
...
@@ -1208,15 +1262,18 @@ def fused_attn(
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_axis
=
context_parallel_axis
,
softmax_offset
=
softmax_offset
,
)
output
=
_fused_attn
(
qkv
,
bias
,
softmax_offset
,
sequence_descriptor
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
...
...
transformer_engine/jax/cpp_extensions/activation.py
View file @
970620a5
...
...
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
(
Quantizer
,
QuantizeLayout
,
DelayedScaleQuantizer
,
ScalingMode
,
QuantizeLayout
,
)
...
...
transformer_engine/jax/cpp_extensions/attention.py
View file @
970620a5
...
...
@@ -20,11 +20,13 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
from
transformer_engine.jax.attention
import
(
AttnBiasType
,
AttnMaskType
,
AttnSoftmaxType
,
QKVLayout
,
QKVFormat
,
CPStrategy
,
SequenceDescriptor
,
)
from
..sharding
import
with_sharding_constraint_by_logical_axes
,
HEAD_AXES
from
.base
import
BasePrimitive
,
register_primitive
from
.misc
import
(
...
...
@@ -61,6 +63,7 @@ __all__ = [
meta_fields
=
[
"attn_bias_type"
,
"attn_mask_type"
,
"softmax_type"
,
"qkv_layout"
,
"scaling_factor"
,
"dropout_probability"
,
...
...
@@ -80,6 +83,7 @@ class _FusedAttnConfig:
attn_bias_type
:
AttnBiasType
attn_mask_type
:
AttnMaskType
softmax_type
:
AttnSoftmaxType
qkv_layout
:
QKVLayout
scaling_factor
:
float
dropout_probability
:
float
...
...
@@ -103,6 +107,7 @@ class FusedAttnHelper:
qkv_layout
:
QKVLayout
attn_bias_type
:
AttnBiasType
attn_mask_type
:
AttnMaskType
softmax_type
:
AttnSoftmaxType
dropout_probability
:
float
q_num_heads
:
int
kv_num_heads
:
int
...
...
@@ -125,6 +130,7 @@ class FusedAttnHelper:
self
.
qkv_layout
.
value
,
self
.
attn_bias_type
.
value
,
self
.
attn_mask_type
.
value
,
self
.
softmax_type
.
value
,
self
.
dropout_probability
,
self
.
q_num_heads
,
self
.
kv_num_heads
,
...
...
@@ -254,7 +260,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
name
=
"te_fused_attn_forward_ffi"
multiple_results
=
True
impl_static_args
=
(
1
3
,)
impl_static_args
=
(
1
4
,)
inner_primitive
=
None
outer_primitive
=
None
...
...
@@ -264,6 +270,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k_aval
,
v_aval
,
bias_aval
,
softmax_offset_aval
,
seed_aval
,
q_seqlen_or_cu_seqlen_aval
,
kv_seqlen_or_cu_seqlen_aval
,
...
...
@@ -312,6 +319,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config
.
qkv_layout
,
config
.
attn_bias_type
,
config
.
attn_mask_type
,
config
.
softmax_type
,
config
.
dropout_probability
,
attn_heads
,
num_gqa_groups
,
...
...
@@ -375,6 +383,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config
.
dropout_probability
,
config
.
attn_bias_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
softmax_type
.
value
,
config
.
qkv_layout
.
value
,
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
config
.
is_training
,
...
...
@@ -386,6 +395,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
)
assert
softmax_offset_aval
.
dtype
==
jnp
.
float32
if
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
assert
softmax_offset_aval
.
shape
==
(
1
,
attn_heads
,
1
,
1
)
else
:
assert
softmax_offset_aval
.
shape
==
(
0
,)
return
out_aval
,
softmax_aux_aval
,
rng_state_aval
,
wkspace_aval
@
staticmethod
...
...
@@ -405,6 +420,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
v
,
bias
,
softmax_offset
,
seed
,
q_cu_seqlen
,
kv_cu_seqlen
,
...
...
@@ -453,6 +469,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
v
,
bias
,
softmax_offset
,
seed
,
q_cu_seqlen
,
kv_cu_seqlen
,
...
...
@@ -481,6 +498,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
window_size_left
=
window_size_left
,
window_size_right
=
window_size_right
,
softmax_type
=
int
(
config
.
softmax_type
.
value
),
)
@
staticmethod
...
...
@@ -489,6 +507,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
v
,
bias
,
softmax_offset
,
seed
,
q_seqlen
,
kv_seqlen
,
...
...
@@ -579,6 +598,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
v
,
bias
,
softmax_offset
,
seed
,
q_cu_seqlen
,
kv_cu_seqlen
,
...
...
@@ -596,7 +616,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
check_valid_batch_dims
(
batch_dims
)
assert
FusedAttnFwdPrimitive
.
outer_primitive
is
not
None
q_bdim
,
_
,
_
,
_
,
seed_bdim
,
*
_
=
batch_dims
q_bdim
,
_
,
_
,
_
,
_
,
seed_bdim
,
*
_
=
batch_dims
out_bdims
=
q_bdim
,
q_bdim
,
seed_bdim
return
(
...
...
@@ -662,7 +682,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
...
...
@@ -710,7 +730,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
name
=
"te_fused_attn_backward_ffi"
multiple_results
=
True
impl_static_args
=
(
1
6
,)
impl_static_args
=
(
1
7
,)
inner_primitive
=
None
outer_primitive
=
None
...
...
@@ -720,6 +740,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_aval
,
v_aval
,
bias_aval
,
softmax_offset_aval
,
softmax_aux_aval
,
rng_state_aval
,
output_aval
,
...
...
@@ -781,6 +802,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config
.
dropout_probability
,
config
.
attn_bias_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
softmax_type
.
value
,
config
.
qkv_layout
.
value
,
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
config
.
is_training
,
...
...
@@ -798,15 +820,39 @@ class FusedAttnBwdPrimitive(BasePrimitive):
shape
=
wkspace_shape
,
dtype
=
te_dtype_to_jax_dtype
(
wkspace_dtype
)
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
wkspace_aval
# Validate incoming softmax_offset shape and dtype
assert
(
softmax_offset_aval
.
dtype
==
jnp
.
float32
),
f
"Incorrect softmax_offset dtype:
{
softmax_offset_aval
.
dtype
}
, expected:
{
jnp
.
float32
}
"
if
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
assert
softmax_offset_aval
.
shape
==
(
1
,
attn_heads
,
1
,
1
),
(
f
"Incorrect softmax_offset shape for
{
config
.
softmax_type
}
:"
f
"
{
softmax_offset_aval
.
shape
}
, expected: (1,
{
attn_heads
}
, 1, 1)"
)
else
:
assert
softmax_offset_aval
.
shape
==
(
0
,),
(
f
"Incorrect softmax_offset shape for
{
config
.
softmax_type
}
:"
f
"
{
softmax_offset_aval
.
shape
}
, expected: (0,)"
)
if
config
.
softmax_type
==
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
dsoftmax_offset_aval
=
q_aval
.
update
(
shape
=
softmax_offset_aval
.
shape
,
dtype
=
softmax_offset_aval
.
dtype
)
else
:
dsoftmax_offset_aval
=
q_aval
.
update
(
shape
=
(
1
,
attn_heads
,
1
,
1
),
dtype
=
jnp
.
float32
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
dsoftmax_offset_aval
,
wkspace_aval
@
staticmethod
def
outer_abstract
(
*
args
,
**
kwargs
):
"""
Fused attention fwd outer primitive abstract
"""
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
_
=
FusedAttnBwdPrimitive
.
abstract
(
*
args
,
**
kwargs
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
dsoftmax_offset_aval
,
_
=
(
FusedAttnBwdPrimitive
.
abstract
(
*
args
,
**
kwargs
)
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
dsoftmax_offset_aval
@
staticmethod
def
lowering
(
...
...
@@ -815,6 +861,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k
,
v
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -866,6 +913,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k
,
v
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -897,6 +945,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
window_size_left
=
window_size_left
,
window_size_right
=
window_size_right
,
softmax_type
=
int
(
config
.
softmax_type
.
value
),
)
@
staticmethod
...
...
@@ -905,6 +954,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k
,
v
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -993,11 +1043,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
q_cu_seqlen
=
generate_cu_seqlen
(
q_seqlen
.
flatten
())
kv_cu_seqlen
=
generate_cu_seqlen
(
kv_seqlen
.
flatten
())
dq
,
dk
,
dv
,
dbias
,
_
=
FusedAttnBwdPrimitive
.
inner_primitive
.
bind
(
dq
,
dk
,
dv
,
dbias
,
dsoftmax_offset
,
_
=
FusedAttnBwdPrimitive
.
inner_primitive
.
bind
(
q
,
k
,
v
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -1012,15 +1063,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_kv_segment_pos
,
config
=
config
,
)
return
dq
,
dk
,
dv
,
dbias
return
dq
,
dk
,
dv
,
dbias
,
dsoftmax_offset
@
staticmethod
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
check_valid_batch_dims
(
batch_dims
)
assert
FusedAttnBwdPrimitive
.
outer_primitive
is
not
None
q_bdim
,
k_bdim
,
v_bdim
,
*
_
=
batch_dims
q_bdim
,
k_bdim
,
v_bdim
,
bias_bdim
,
softmax_offset_bdim
,
*
_
=
batch_dims
out_bdims
=
q_bdim
,
k_bdim
,
v_bdim
,
q
_bdim
out_bdims
=
q_bdim
,
k_bdim
,
v_bdim
,
bias_bdim
,
softmax_offset
_bdim
return
(
FusedAttnBwdPrimitive
.
outer_primitive
.
bind
(
*
batched_args
,
config
=
config
),
out_bdims
,
...
...
@@ -1033,11 +1084,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
return
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
return
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
)
@
staticmethod
def
partition
(
config
,
mesh
,
arg_infos
,
result_infos
):
...
...
@@ -1046,21 +1099,30 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
,
)
def
sharded_impl
(
q
,
k
,
v
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -1074,11 +1136,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_q_segment_pos
,
_kv_segment_pos
,
):
local_dq
,
local_dk
,
local_dv
,
local_dbias
=
FusedAttnBwdPrimitive
.
impl
(
local_dq
,
local_dk
,
local_dv
,
local_dbias
,
local_dsoftmax_offset
=
(
FusedAttnBwdPrimitive
.
impl
(
q
,
k
,
v
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -1093,17 +1157,22 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_kv_segment_pos
,
config
=
config
,
)
)
global_dbias
=
local_dbias
if
config
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
:
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
local_dbias
,
mesh
)
return
local_dq
,
local_dk
,
local_dv
,
global_dbias
global_dsoftmax_offset
=
local_dsoftmax_offset
if
config
.
softmax_type
==
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
global_dsoftmax_offset
=
all_reduce_sum_along_dp_fsdp
(
local_dsoftmax_offset
,
mesh
)
return
local_dq
,
local_dk
,
local_dv
,
global_dbias
,
global_dsoftmax_offset
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
@
staticmethod
def
shardy_sharding_rule
(
config
,
mesh
,
value_types
,
result_types
):
del
config
,
mesh
# We only care about the four first arguments.
# Keep in sync with `infer_sharding_from_operands`.
input_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
value_types
)))
output_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
result_types
)))
...
...
@@ -1229,6 +1298,11 @@ class _FusedAttnCPWithAllGatherHelper:
if
self
.
config
.
dropout_probability
!=
0.0
:
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
if
self
.
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
raise
ValueError
(
f
"
{
header
}
only supports VANILLA_SOFTMAX, got:
{
self
.
config
.
softmax_type
}
"
)
def
get_adjusted_mask
(
self
):
"""Converts the mask for context parallelism."""
if
self
.
config
.
attn_mask_type
==
AttnMaskType
.
CAUSAL_MASK
:
...
...
@@ -1240,6 +1314,7 @@ class _FusedAttnCPWithAllGatherHelper:
return
_FusedAttnConfig
(
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_mask_type
=
self
.
get_adjusted_mask
(),
softmax_type
=
self
.
config
.
softmax_type
,
qkv_layout
=
self
.
config
.
qkv_layout
,
scaling_factor
=
self
.
config
.
scaling_factor
,
dropout_probability
=
self
.
config
.
dropout_probability
,
...
...
@@ -1376,7 +1451,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
out_sharding
,
softmax_aux_sharding
,
rng_state_sharding
)
...
...
@@ -1385,6 +1460,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k
,
v
,
bias
,
softmax_offset
,
seed
,
q_seqlen
,
kv_seqlen
,
...
...
@@ -1404,7 +1480,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
# meeting the expectation of the SPMD model.
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# mask/sequence length tensor to avoid this unrolled loop.
def
_cross_attn
(
idx
,
q
,
k
,
v
,
bias
,
q_seqlen
,
kv_seqlen
,
seed
):
def
_cross_attn
(
idx
,
q
,
k
,
v
,
bias
,
softmax_offset
,
q_seqlen
,
kv_seqlen
,
seed
):
kv_max_seqlen
=
k
.
shape
[
1
]
kv_seqlen_per_subrank
=
kv_max_seqlen
//
(
cp_size
*
2
)
assert
kv_max_seqlen
%
cp_size
==
0
,
"sequence length must evenly divide cp size"
...
...
@@ -1431,6 +1507,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_unmasked
,
v_unmasked
,
bias
,
softmax_offset
,
seed
,
q_seqlen_for_step
,
kv_seqlen_for_step
,
...
...
@@ -1453,7 +1530,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_ag
,
v_ag
=
helper
.
all_gather_kv
(
k
,
v
)
functions
=
[
partial
(
_cross_attn
,
idx
,
q
,
k_ag
,
v_ag
,
bias
,
q_seqlen
,
kv_seqlen
,
seed
)
partial
(
_cross_attn
,
idx
,
q
,
k_ag
,
v_ag
,
bias
,
softmax_offset
,
q_seqlen
,
kv_seqlen
,
seed
)
for
idx
in
range
(
cp_size
)
]
...
...
@@ -1492,18 +1571,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
,
)
def
impl
(
q
,
k
,
v
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -1527,6 +1615,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k
,
v
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -1562,11 +1651,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
num_kv_chunks
=
kv_max_seqlen
//
kv_seqlens_for_rank
[
sub_idx
]
kv_seqlen_for_step
=
(
kv_seqlen
//
(
cp_size
*
2
))
*
num_kv_chunks
dq_local
,
dk_local
,
dv_local
,
dbias_local
=
FusedAttnBwdPrimitive
.
impl
(
dq_local
,
dk_local
,
dv_local
,
dbias_local
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q_split
[
sub_idx
],
k_unmasked
,
v_unmasked
,
bias
,
softmax_offset
,
softmax_aux_split
[
sub_idx
],
rng_state
,
output_split
[
sub_idx
],
...
...
@@ -1604,6 +1694,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_ag
,
v_ag
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -1621,7 +1712,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
dq
,
dk_local
,
dv_local
,
dbias
=
lax
.
switch
(
cp_rank
,
functions
)
dk
,
dv
=
helper
.
reduce_scatter_dkv
(
dk_local
,
dv_local
)
return
dq
,
dk
,
dv
,
dbias
# Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it)
dummy_dsoftmax_offset
=
jnp
.
empty_like
(
softmax_offset
)
return
dq
,
dk
,
dv
,
dbias
,
dummy_dsoftmax_offset
return
mesh
,
impl
,
out_shardings
,
arg_shardings
...
...
@@ -1679,6 +1772,11 @@ class _FusedAttnCPWithP2PHelper:
if
self
.
config
.
dropout_probability
!=
0.0
:
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
if
self
.
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
raise
ValueError
(
f
"
{
header
}
only supports VANILLA_SOFTMAX, got:
{
self
.
config
.
softmax_type
}
"
)
# We want to encourage use of scan loop to minimize unrolling and ensure more
# predictable scheduling from XLA. The unrolled flavor will be supported but
# not the prefered implementation.
...
...
@@ -1703,6 +1801,7 @@ class _FusedAttnCPWithP2PHelper:
return
_FusedAttnConfig
(
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
self
.
config
.
softmax_type
,
qkv_layout
=
QKVLayout
.
BSHD_BS2HD
,
scaling_factor
=
self
.
config
.
scaling_factor
,
dropout_probability
=
self
.
config
.
dropout_probability
,
...
...
@@ -1783,7 +1882,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
# Ensure segment_pos gets same sharding as ID.
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
...
...
@@ -1795,6 +1894,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
k
,
v
,
bias
,
_softmax_offset
,
seed
,
q_seqlen
,
kv_seqlen
,
...
...
@@ -1840,6 +1940,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv
,
_not_used
,
bias
,
_softmax_offset
,
seed
,
q_seqlen_per_step
,
kv_seqlen_per_step
,
...
...
@@ -1865,6 +1966,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv_part
,
_not_used
,
bias
,
_softmax_offset
,
seed
,
q_seqlen_per_step
,
kv_seqlen_per_step
,
...
...
@@ -1887,6 +1989,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv
,
_not_used
,
bias
,
_softmax_offset
,
seed
,
q_seqlen_per_step
,
kv_seqlen_per_step
,
...
...
@@ -1990,18 +2093,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
# Ring attention doesn't use dsoftmax_offset, but we need to return it for arity matching
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
# Ensure segment_pos gets same sharding as ID.
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
,
)
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
.
check_supported
()
...
...
@@ -2011,6 +2120,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k
,
v
,
bias
,
_softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -2054,11 +2164,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
def
mask_compute
(
attn_mask_type
):
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q
,
kv
,
_not_used
,
bias
,
_softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -2082,11 +2193,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
//
2
kv_part
=
lax
.
slice_in_dim
(
kv
,
0
,
kv_max_seqlen
//
2
,
axis
=
1
)
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q
,
kv_part
,
_not_used
,
bias
,
_softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -2120,11 +2232,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
softmax_aux
,
q_max_seqlen
//
2
,
q_max_seqlen
,
axis
=
2
)
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q_part
,
kv
,
_not_used
,
bias
,
_softmax_offset
,
softmax_aux_part
,
rng_state
,
output_part
,
...
...
@@ -2184,7 +2297,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
dk
,
dv
=
helper
.
unstack_kv
(
dk_dv
)
return
dq
,
dk
,
dv
,
global_dbias
# Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset
=
jnp
.
empty_like
(
_softmax_offset
)
return
dq
,
dk
,
dv
,
global_dbias
,
dummy_dsoftmax_offset
return
mesh
,
ring_attn_bwd_impl
,
out_shardings
,
arg_shardings
...
...
@@ -2273,7 +2388,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
# Ensure segment_pos gets same sharding as ID.
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
...
...
@@ -2285,6 +2400,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
k
,
v
,
bias
,
_softmax_offset
,
seed
,
q_seqlen
,
kv_seqlen
,
...
...
@@ -2336,6 +2452,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv
,
_not_used
,
bias
,
_softmax_offset
,
seed
,
q_seqlen
,
kv_seqlen
,
...
...
@@ -2345,7 +2462,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv_segment_ids
,
q_segment_pos
,
kv_segment_pos
,
config
,
config
=
config
,
)
if
config
.
window_size
!=
(
-
1
,
-
1
):
...
...
@@ -2420,8 +2537,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding
out_shardings
=
tuple
(
arg
.
sharding
for
arg
in
arg_infos
[:
4
])
# dq, dk, dv, dbias
, dsoftmax_offset
sharding = q, k, v, bias
, softmax_offset
sharding
out_shardings
=
tuple
(
arg
.
sharding
for
arg
in
arg_infos
[:
5
])
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
.
check_supported
()
...
...
@@ -2431,6 +2548,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
k
,
v
,
bias
,
_softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -2478,11 +2596,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
kv_segment_pos_next
=
helper
.
permute_kv
(
kv_segment_pos
,
cp_perm
)
def
compute
(
config
):
dq_per_step
,
dkv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dkv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q
,
kv
,
_not_used
,
bias
,
_softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -2536,7 +2655,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
dk
,
dv
=
helper
.
unstack_kv
(
dkv
)
return
dq
,
dk
,
dv
,
global_dbias
# Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset
=
jnp
.
empty_like
(
_softmax_offset
)
return
dq
,
dk
,
dv
,
global_dbias
,
dummy_dsoftmax_offset
return
mesh
,
bwd_impl
,
out_shardings
,
arg_shardings
...
...
@@ -2557,10 +2678,12 @@ def _maybe_context_parallel_axis(cp_axis: str):
def
fused_attn_fwd
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
sequence_descriptor
:
SequenceDescriptor
,
seed
:
Optional
[
jnp
.
ndarray
],
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
softmax_type
:
AttnSoftmaxType
,
qkv_layout
:
QKVLayout
,
scaling_factor
:
float
,
dropout_probability
:
float
,
...
...
@@ -2585,6 +2708,7 @@ def fused_attn_fwd(
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
q_seq_offsets (jnp.ndarray):
...
...
@@ -2594,6 +2718,7 @@ def fused_attn_fwd(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
...
...
@@ -2633,10 +2758,36 @@ def fused_attn_fwd(
assert
bias
is
None
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
if
softmax_offset
is
None
:
assert
(
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
),
f
"Softmax type
{
softmax_type
}
is not supported when softmax_offset is None"
if
softmax_type
==
AttnSoftmaxType
.
OFF_BY_ONE_SOFTMAX
:
num_heads
=
qkv
[
0
].
shape
[
-
2
]
# Create tensor [1, h, 1, 1] filled with zeros (logit value = 0)
# This adds exp(0 - x_max) = exp(-x_max) to the denominator,
# which contributes exactly 1 after normalization, giving: exp(x_i) / (sum(exp(x_j)) + 1)
softmax_offset
=
jnp
.
zeros
((
1
,
num_heads
,
1
,
1
),
dtype
=
jnp
.
float32
)
# Shard by heads dimension
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
else
:
assert
softmax_type
==
AttnSoftmaxType
.
VANILLA_SOFTMAX
softmax_offset
=
jnp
.
zeros
(
0
,
dtype
=
jnp
.
float32
)
else
:
assert
softmax_offset
.
dtype
==
jnp
.
float32
# Shard by heads dimension if not VANILLA_SOFTMAX
if
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
fused_config
=
_FusedAttnConfig
(
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
...
...
@@ -2662,6 +2813,7 @@ def fused_attn_fwd(
output
,
softmax_aux
,
rng_state
=
primitive
.
bind
(
*
qkv_for_primitive
,
bias
,
softmax_offset
,
seed
,
*
seq_desc_flatten
,
config
=
fused_config
,
...
...
@@ -2673,6 +2825,7 @@ def fused_attn_fwd(
def
fused_attn_bwd
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
softmax_aux
:
jnp
.
ndarray
,
rng_state
:
jnp
.
ndarray
,
output
:
jnp
.
ndarray
,
...
...
@@ -2681,6 +2834,7 @@ def fused_attn_bwd(
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
...
...
@@ -2702,6 +2856,7 @@ def fused_attn_bwd(
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
output (jnp.ndarray): The output tensor from the forward pass.
...
...
@@ -2714,6 +2869,7 @@ def fused_attn_bwd(
The offsets in the sequence dim for the query, with shape [batch + 1,].
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
...
...
@@ -2755,6 +2911,28 @@ def fused_attn_bwd(
assert
bias
is
None
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
if
softmax_offset
is
None
:
assert
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
,
f
"Unknown
{
softmax_type
=
}
"
if
softmax_type
==
AttnSoftmaxType
.
OFF_BY_ONE_SOFTMAX
:
num_heads
=
qkv
[
0
].
shape
[
-
2
]
# Create tensor [1, h, 1, 1] filled with zeros
softmax_offset
=
jnp
.
zeros
((
1
,
num_heads
,
1
,
1
),
dtype
=
jnp
.
float32
)
# Shard by heads dimension
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
elif
softmax_type
==
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_offset
=
jnp
.
zeros
(
0
,
dtype
=
jnp
.
float32
)
else
:
raise
NotImplementedError
(
f
"Unknown
{
softmax_type
=
}
"
)
else
:
softmax_offset
=
softmax_offset
.
astype
(
jnp
.
float32
)
# Shard by heads dimension if not VANILLA_SOFTMAX
if
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+
compute_capabilities
=
get_all_device_compute_capability
()
...
...
@@ -2767,6 +2945,7 @@ def fused_attn_bwd(
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
...
...
@@ -2788,9 +2967,10 @@ def fused_attn_bwd(
primitive
=
FusedRingAttnBwdPrimitive
.
outer_primitive
seq_desc_flatten
,
_
=
jax
.
tree
.
flatten
(
sequence_descriptor
)
*
qkv_grads
,
bias_grad
=
primitive
.
bind
(
*
qkv_grads
,
bias_grad
,
softmax_offset_grad
=
primitive
.
bind
(
*
qkv_for_primitive
,
bias
,
softmax_offset
,
softmax_aux
,
rng_state
,
output
,
...
...
@@ -2798,4 +2978,4 @@ def fused_attn_bwd(
*
seq_desc_flatten
,
config
=
fused_config
,
)
return
tuple
(
qkv_grads
[:
len
(
qkv
)]),
bias_grad
return
tuple
(
qkv_grads
[:
len
(
qkv
)]),
bias_grad
,
softmax_offset_grad
transformer_engine/jax/cpp_extensions/gemm.py
View file @
970620a5
...
...
@@ -39,12 +39,12 @@ from ..quantize import (
Quantizer
,
GroupedQuantizer
,
QuantizerSet
,
QuantizeLayout
,
noop_quantizer_set
,
is_fp8_gemm_with_all_layouts_supported
,
apply_padding_to_scale_inv
,
get_quantize_config_with_recipe
,
get_global_quantize_recipe
,
QuantizeLayout
,
)
from
.misc
import
get_padded_spec
,
is_all_reduce_in_float32
from
..sharding
import
(
...
...
transformer_engine/jax/cpp_extensions/misc.py
View file @
970620a5
...
...
@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
transpose. Note, transpose_axis should be greater than static_axis_boundary
examples:
X
in
shape (dim0, dim1, dim2, dim3, dim4)
X
of
shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
...
...
transformer_engine/jax/cpp_extensions/normalization.py
View file @
970620a5
...
...
@@ -35,9 +35,9 @@ from ..sharding import (
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
(
Quantizer
,
QuantizeLayout
,
DelayedScaleQuantizer
,
ScalingMode
,
QuantizeLayout
,
)
...
...
transformer_engine/jax/cpp_extensions/quantization.py
View file @
970620a5
...
...
@@ -40,11 +40,11 @@ from ..quantize import (
GroupedScaledTensor1x
,
Quantizer
,
GroupedQuantizer
,
QuantizeLayout
,
ScalingMode
,
compute_scale_from_amax
,
NoScaleTensor
,
get_rht_matrix
,
QuantizeLayout
,
)
...
...
@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
amax_spec
=
get_padded_spec
(
arg_infos
[
2
])
sr_rng_state_spec
=
get_padded_spec
(
arg_infos
[
3
])
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
),
...
...
@@ -551,9 +552,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
arg_shardings
=
list
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
if
len
(
sr_rng_state_spec
)
>
1
:
# sr_rng_state shape [n_devices, state_per_device]
sr_rng_state_spec
=
(
*
tuple
(
x
for
x
in
x_spec
if
x
is
not
None
),
None
)
arg_shardings
[
3
]
=
NamedSharding
(
mesh
,
PartitionSpec
(
tuple
(
x
for
x
in
x_spec
if
x
is
not
None
),
None
),
PartitionSpec
(
*
sr_rng_state_spec
),
desc
=
"BaseDBiasQuantizePrimitive.sr_rng_state"
,
)
arg_shardings
=
tuple
(
arg_shardings
)
...
...
@@ -654,9 +658,11 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias
=
input_spec
[
flatten_axis
:]
if
is_dbias
else
(
prefix
+
"_dbias"
,)
amax
=
(
BATCHING
+
prefix
+
"_amax"
,)
scale
=
(
BATCHING
+
prefix
+
"_scale"
,)
sr_rng_state
=
(
BATCHING
+
prefix
+
"_sr_rng_state"
,)
if
value_types
[
3
].
shape
!=
[
0
]:
sr_rng_state
=
(
BATCHING
+
prefix
+
"_sr_rng_state_
partition_axi
s"
,
BATCHING
+
prefix
+
"sr_rng_state_data
_axis
"
,
BATCHING
+
prefix
+
"_sr_rng_state_
device
s"
,
prefix
+
"sr_rng_state_data"
,
)
post_rht_amax
=
(
BATCHING
+
prefix
+
"_post_rht_amax"
,)
...
...
@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
if
force_1x_quantization
:
q_layout
=
QuantizeLayout
.
ROWWISE
sr_rng_state
=
None
sr_rng_state
=
jnp
.
empty
((
0
,),
jnp
.
uint32
)
if
quantizer
.
scaling_mode
.
is_nvfp4_scaling
:
# Only NVFP4 scaling modes support stochastic rounding
if
quantizer
.
stochastic_rounding_rng_state
is
not
None
:
...
...
@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
x
.
data
,
scale
,
amax
,
(
sr_rng_state
if
sr_rng_state
is
not
None
else
jnp
.
empty
((
get_num_devices_in_mesh
(),
1
),
jnp
.
uint32
)
),
sr_rng_state
,
post_rht_amax
if
post_rht_amax
is
not
None
else
jnp
.
zeros
((
1
,),
jnp
.
float32
),
rht_matrix
,
out_dtype
=
quantizer
.
q_dtype
,
...
...
@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
scale_dtype
=
quantizer
.
get_scale_dtype
(),
is_dbias
=
is_dbias
if
not
quantizer
.
scaling_mode
.
is_nvfp4_scaling
else
False
,
is_outer
=
True
,
stochastic_rounding
=
sr_rng_state
is
not
None
,
stochastic_rounding
=
sr_rng_state
.
size
!=
0
,
use_rht
=
use_rht
,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
...
...
transformer_engine/jax/cpp_extensions/softmax.py
View file @
970620a5
...
...
@@ -11,10 +11,11 @@ import jax
import
jax.numpy
as
jnp
from
jax
import
dtypes
,
ffi
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
.attention
import
AttnSoftmaxType
from
.base
import
BasePrimitive
,
register_primitive
from
.misc
import
get_padded_spec
,
check_valid_batch_dims
from
..softmax
import
SoftmaxType
from
..softmax
import
Softmax
Fusion
Type
__all__
=
[
...
...
@@ -32,7 +33,8 @@ __all__ = [
def
is_softmax_kernel_available
(
softmax_type
:
SoftmaxType
,
softmax_fusion_type
:
SoftmaxFusionType
,
softmax_type
:
AttnSoftmaxType
,
batch
:
int
,
heads
:
int
,
q_seqlen
:
int
,
...
...
@@ -40,15 +42,18 @@ def is_softmax_kernel_available(
dtype
:
jnp
.
dtype
,
):
"""check softmax available"""
if
softmax_type
is
SoftmaxType
.
SCALED
:
if
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
return
False
if
softmax_fusion_type
is
SoftmaxFusionType
.
SCALED
:
return
ScaledSoftmaxFwdPrimitive
.
is_kernel_available
(
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
)
if
softmax_type
is
SoftmaxType
.
SCALED_MASKED
:
if
softmax_
fusion_
type
is
Softmax
Fusion
Type
.
SCALED_MASKED
:
return
ScaledMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
)
if
softmax_type
is
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
:
if
softmax_
fusion_
type
is
Softmax
Fusion
Type
.
SCALED_UPPER_TRIANG_MASKED
:
return
ScaledUpperTriangMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
)
...
...
@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive
(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
)
def
jax_scaled_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
,
softmax_offset
:
jnp
.
ndarray
|
float
|
None
=
None
):
"""
JAX based implementation of scaled softmax
"""
if
softmax_offset
is
not
None
:
return
jax_general_softmax
(
scale_factor
*
logits
,
offset
=
softmax_offset
)
return
jax
.
nn
.
softmax
(
scale_factor
*
logits
)
def
jax_scaled_masked_softmax
(
logits
:
jnp
.
ndarray
,
mask
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_masked_softmax
(
logits
:
jnp
.
ndarray
,
mask
:
jnp
.
ndarray
,
scale_factor
:
float
,
softmax_offset
:
jnp
.
ndarray
|
float
|
None
=
None
,
):
"""
JAX based implementation of scaled and masked softmax
"""
if
softmax_offset
is
not
None
:
return
jax_general_softmax
(
logits
*
scale_factor
,
offset
=
softmax_offset
,
where
=
mask
!=
1
)
return
jax
.
nn
.
softmax
(
logits
*
scale_factor
,
where
=
mask
!=
1
)
def
jax_scaled_upper_triang_masked_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_upper_triang_masked_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
,
softmax_offset
:
jnp
.
ndarray
|
float
|
None
=
None
):
"""
JAX based implementation of scaled and upper triangle masked softmax
"""
mask
=
1
-
jnp
.
tril
(
jnp
.
ones_like
(
logits
))
return
jax_scaled_masked_softmax
(
logits
,
mask
,
scale_factor
)
return
jax_scaled_masked_softmax
(
logits
,
mask
,
scale_factor
,
softmax_offset
)
def
jax_general_softmax
(
x
:
jnp
.
ndarray
,
axis
:
int
=
-
1
,
where
:
jnp
.
ndarray
|
None
=
None
,
initial
:
jnp
.
ndarray
=
-
jnp
.
inf
,
offset
:
jnp
.
ndarray
|
float
|
None
=
None
,
)
->
jnp
.
ndarray
:
"""
JAX based implementation of general softmax with optional masking and offset.
"""
# Compute max of x
x_max
=
jnp
.
max
(
x
,
axis
,
where
=
where
,
initial
=
initial
,
keepdims
=
True
)
if
offset
is
not
None
:
# Cast offset to x.dtype to prevent type promotion
if
isinstance
(
offset
,
(
int
,
float
)):
offset
=
jnp
.
array
(
offset
,
dtype
=
x
.
dtype
)
else
:
offset
=
offset
.
astype
(
x
.
dtype
)
# Include offset in max: x_max = max(x_max, offset)
# This is equivalent to computing max over [x..., offset]
x_max
=
jnp
.
maximum
(
x_max
,
offset
)
x_safe
=
x
if
where
is
None
else
jnp
.
where
(
where
,
x
,
initial
)
unnormalized
=
jnp
.
exp
(
x_safe
-
x_max
)
denominator
=
jnp
.
sum
(
unnormalized
,
axis
,
where
=
where
,
keepdims
=
True
)
if
offset
is
not
None
:
# Add exp(offset - x_max) to denominator
denominator
=
denominator
+
jnp
.
exp
(
offset
-
x_max
)
result
=
unnormalized
/
denominator
if
where
is
not
None
:
result
=
jnp
.
where
(
where
,
result
,
0
)
return
result
def
scaled_softmax_fwd
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
)
->
jnp
.
ndarray
:
...
...
transformer_engine/jax/csrc/extensions.h
View file @
970620a5
...
...
@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
FusedAttnBackwardHandler
);
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
float
dropout_probability
,
size_t
q_num_heads
,
size_t
kv_num_heads
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_left
,
int64_t
window_size_right
);
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
dropout_probability
,
size_t
q_attn_heads
,
size_t
kv_attn_heads
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_left
,
int64_t
window_size_right
);
pybind11
::
tuple
GetFusedAttnForwardWorkspaceSizes
(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
);
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
);
pybind11
::
tuple
GetFusedAttnBackwardWorkspaceSizes
(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_
QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
);
NVTE_Mask_Type
mask_type
,
NVTE_
Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
);
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
GemmHandler
);
...
...
transformer_engine/jax/csrc/extensions/attention.cpp
View file @
970620a5
...
...
@@ -11,14 +11,12 @@
namespace
transformer_engine
{
namespace
jax
{
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
float
dropout_probability
,
size_t
q_attn_heads
,
size_t
kv_attn_heads
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
dropout_probability
,
size_t
q_attn_heads
,
size_t
kv_attn_heads
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
auto
backend
=
nvte_get_fused_attn_backend
(
is_training
,
static_cast
<
NVTEDType
>
(
q_dtype
),
static_cast
<
NVTEDType
>
(
kv_dtype
),
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
dropout_probability
,
q_attn_heads
,
kv_attn_heads
,
...
...
@@ -39,7 +37,8 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
const
size_t
kv_max_seqlen
,
DType
dtype
,
NVTE_Bias_Type
bias_type
,
NVTE_Fused_Attn_Backend
backend
,
void
*
softmax_buf
,
void
*
rng_state_buf
=
nullptr
,
void
*
bias_buf
=
nullptr
)
{
void
*
bias_buf
=
nullptr
,
void
*
softmax_offset_buf
=
nullptr
)
{
// all backends need softmax but expect different shapes/dtypes
// start with the max512 sequence length softmax shape/dtype and correct later
tensor_pack
->
size
=
1
;
...
...
@@ -67,10 +66,12 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
softmax_aux_data
.
shape
.
data
[
3
]
=
1
;
// {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux_data
.
dtype
=
static_cast
<
NVTEDType
>
(
DType
::
kFloat32
);
int
size
=
2
;
// Start at 2 (we have softmax and rng_state at indices 0, 1)
// include bias if enabled
if
(
bias_type
!=
NVTE_Bias_Type
::
NVTE_NO_BIAS
&&
bias_type
!=
NVTE_Bias_Type
::
NVTE_ALIBI
)
{
tensor_pack
->
size
=
3
;
NVTETensor
&
bias_aux
=
tensor_pack
->
tensors
[
2
]
;
NVTETensor
&
bias_aux
=
tensor_pack
->
tensors
[
size
]
;
size
++
;
NVTEBasicTensor
bias_aux_data
;
bias_aux_data
.
data_ptr
=
bias_buf
;
bias_aux_data
.
shape
.
ndim
=
4
;
...
...
@@ -81,6 +82,24 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
bias_aux_data
.
dtype
=
static_cast
<
NVTEDType
>
(
dtype
);
nvte_set_tensor_param
(
&
bias_aux
,
kNVTERowwiseData
,
&
bias_aux_data
);
}
// include softmax_offset if provided
if
(
softmax_offset_buf
!=
nullptr
)
{
NVTETensor
&
softmax_offset_aux
=
tensor_pack
->
tensors
[
size
];
size
++
;
NVTEBasicTensor
softmax_offset_aux_data
;
softmax_offset_aux_data
.
data_ptr
=
softmax_offset_buf
;
softmax_offset_aux_data
.
shape
.
ndim
=
4
;
softmax_offset_aux_data
.
shape
.
data
[
0
]
=
1
;
softmax_offset_aux_data
.
shape
.
data
[
1
]
=
attn_heads
;
softmax_offset_aux_data
.
shape
.
data
[
2
]
=
1
;
softmax_offset_aux_data
.
shape
.
data
[
3
]
=
1
;
softmax_offset_aux_data
.
dtype
=
static_cast
<
NVTEDType
>
(
DType
::
kFloat32
);
nvte_set_tensor_param
(
&
softmax_offset_aux
,
kNVTERowwiseData
,
&
softmax_offset_aux_data
);
}
// Set final size
tensor_pack
->
size
=
size
;
}
nvte_set_tensor_param
(
&
softmax_aux
,
kNVTERowwiseData
,
&
softmax_aux_data
);
}
...
...
@@ -98,14 +117,16 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
const
size_t
bias_heads
,
const
size_t
q_max_seqlen
,
const
size_t
kv_max_seqlen
,
DType
dtype
,
NVTE_Fused_Attn_Backend
backend
,
void
*
softmax_buf
,
void
*
rng_state_buf
,
void
*
bias_buf
)
{
void
*
rng_state_buf
,
void
*
bias_buf
,
void
*
softmax_offset_buf
=
nullptr
)
{
// Backward calls put everything into the tensor pack for every backend
// so we set dummy bias_type and backend choices here to follow the correct code path
auto
dummy_bias_type
=
NVTE_Bias_Type
::
NVTE_POST_SCALE_BIAS
;
auto
dummy_backend
=
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
;
PrepareFusedAttnForwardAuxTensors
(
tensor_pack
,
input_batch
,
bias_batch
,
attn_heads
,
bias_heads
,
q_max_seqlen
,
kv_max_seqlen
,
dtype
,
dummy_bias_type
,
dummy_backend
,
softmax_buf
,
rng_state_buf
,
bias_buf
);
dummy_backend
,
softmax_buf
,
rng_state_buf
,
bias_buf
,
softmax_offset_buf
);
// correct softmax shape for max512 sequence length kernel
if
(
backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
...
...
@@ -121,8 +142,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
auto
q_shape
=
std
::
vector
<
size_t
>
{
input_batch
*
q_max_seqlen
,
attn_heads
,
qk_head_dim
};
auto
q_tensor
=
TensorWrapper
(
nullptr
,
q_shape
,
dtype
);
auto
k_shape
=
std
::
vector
<
size_t
>
{
input_batch
*
kv_max_seqlen
,
num_gqa_groups
,
qk_head_dim
};
...
...
@@ -141,7 +163,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto
dummy_page_table_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kInt32
);
auto
dummy_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
NVTETensorPack
aux_output_tensors
;
nvte_tensor_pack_create
(
&
aux_output_tensors
);
...
...
@@ -208,18 +229,21 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
static
void
FusedAttnForwardImpl
(
cudaStream_t
stream
,
void
*
q
,
void
*
k
,
void
*
v
,
void
*
bias
,
void
*
seed
,
void
*
q_cu_seqlens
,
void
*
kv_cu_seqlens
,
void
*
q_seq_offsets
,
void
*
k_seq_offsets
,
void
*
output
,
void
*
softmax_aux
,
void
*
rng_state
,
void
*
workspace
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
size_t
max_segments_per_seq
,
size_t
wkspace_size
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
DType
wkspace_dtype
,
bool
is_training
,
bool
deterministic
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
cudaStream_t
stream
,
void
*
q
,
void
*
k
,
void
*
v
,
void
*
bias
,
void
*
softmax_offset
,
void
*
seed
,
void
*
q_cu_seqlens
,
void
*
kv_cu_seqlens
,
void
*
q_seq_offsets
,
void
*
k_seq_offsets
,
void
*
output
,
void
*
softmax_aux
,
void
*
rng_state
,
void
*
workspace
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
size_t
max_segments_per_seq
,
size_t
wkspace_size
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
DType
wkspace_dtype
,
bool
is_training
,
bool
deterministic
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
FUSED_ATTN_IMPL_COMMON_BLOCK
;
/* Input tensors */
auto
bias_tensor
=
TensorWrapper
(
bias
,
bias_shape
,
dtype
);
auto
softmax_offset_tensor
=
TensorWrapper
(
softmax_offset
,
std
::
vector
<
size_t
>
{
1
,
attn_heads
,
1
,
1
},
DType
::
kFloat32
);
if
(
is_ragged
)
{
auto
output_size
=
input_batch
*
q_max_seqlen
*
attn_heads
*
v_head_dim
;
...
...
@@ -238,10 +262,6 @@ static void FusedAttnForwardImpl(
/* Prepare RNG state */
auto
rng_state_tensor
=
TensorWrapper
(
rng_state
,
std
::
vector
<
size_t
>
{
2
},
DType
::
kInt64
);
auto
dummy_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
auto
backend
=
nvte_get_fused_attn_backend
(
is_training
,
static_cast
<
NVTEDType
>
(
dtype
),
static_cast
<
NVTEDType
>
(
dtype
),
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
dropout_probability
,
attn_heads
,
num_gqa_groups
,
...
...
@@ -254,7 +274,7 @@ static void FusedAttnForwardImpl(
nvte_tensor_pack_create
(
&
aux_output_tensors
);
PrepareFusedAttnForwardAuxTensors
(
&
aux_output_tensors
,
input_batch
,
bias_batch
,
attn_heads
,
bias_heads
,
q_max_seqlen
,
kv_max_seqlen
,
dtype
,
bias_type
,
backend
,
softmax_aux
);
backend
,
softmax_aux
,
softmax_offset
);
/* Call the underlying NVTE API */
auto
dummy_page_table_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kInt32
);
...
...
@@ -303,7 +323,7 @@ static void FusedAttnForwardImpl(
nvte_fused_attn_fwd
(
q_tensor
.
data
(),
k_tensor
.
data
(),
v_tensor
.
data
(),
bias_tensor
.
data
(),
dummy_
softmax_offset_tensor
.
data
(),
s_tensor
.
data
(),
o_tensor
.
data
(),
&
aux_output_tensors
,
softmax_offset_tensor
.
data
(),
s_tensor
.
data
(),
o_tensor
.
data
(),
&
aux_output_tensors
,
q_cu_seqlens_tensor
.
data
(),
kv_cu_seqlens_tensor
.
data
(),
q_seq_offsets_tensor
.
data
(),
k_seq_offsets_tensor
.
data
(),
dummy_page_table_tensor
.
data
(),
dummy_page_table_tensor
.
data
(),
rng_state_tensor
.
data
(),
q_max_seqlen
,
kv_max_seqlen
,
is_training
,
false
,
false
,
...
...
@@ -332,6 +352,8 @@ static void FusedAttnForwardImpl(
static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type")); \
NVTE_Mask_Type mask_type = \
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type")); \
NVTE_Softmax_Type softmax_type = \
static_cast<NVTE_Softmax_Type>(get_attr_value<int64_t>(attrs, "softmax_type")); \
NVTE_QKV_Layout qkv_layout = \
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout")); \
bool is_training = get_attr_value<bool>(attrs, "is_training"); \
...
...
@@ -342,7 +364,8 @@ static void FusedAttnForwardImpl(
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
Error_Type
FusedAttnForwardFFI
(
cudaStream_t
stream
,
Buffer_Type
q_buf
,
Buffer_Type
k_buf
,
Buffer_Type
v_buf
,
Buffer_Type
bias_buf
,
Buffer_Type
seed_buf
,
Buffer_Type
v_buf
,
Buffer_Type
bias_buf
,
Buffer_Type
softmax_offset_buf
,
Buffer_Type
seed_buf
,
Buffer_Type
q_cu_seqlens_buf
,
Buffer_Type
kv_cu_seqlens_buf
,
Buffer_Type
q_seq_offsets_buf
,
Buffer_Type
k_seq_offsets_buf
,
Variadic_Buffer_Type
_unused_args
,
Result_Type
output_buf
,
...
...
@@ -352,15 +375,15 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
FusedAttnForwardImpl
(
stream
,
q_buf
.
untyped_data
(),
k_buf
.
untyped_data
(),
v_buf
.
untyped_data
(),
bias_buf
.
untyped_data
(),
seed_buf
.
untyped_data
(),
q_cu_seqlens_buf
.
untyped_data
(),
kv_cu_seqlens_buf
.
untyped_data
(),
is_ragged
?
q_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
bias_buf
.
untyped_data
(),
softmax_offset_buf
.
untyped_data
(),
seed_buf
.
untyped_data
(),
q_cu_seqlens_buf
.
untyped_data
(),
kv_cu_seqlens_buf
.
untyped_data
(),
is_ragged
?
q_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
is_ragged
?
k_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
output_buf
->
untyped_data
(),
softmax_aux_buf
->
untyped_data
(),
rng_state_buf
->
untyped_data
(),
workspace_buf
->
untyped_data
(),
input_batch
,
bias_batch
,
q_max_seqlen
,
kv_max_seqlen
,
attn_heads
,
num_gqa_groups
,
bias_heads
,
qk_head_dim
,
v_head_dim
,
max_segments_per_seq
,
wkspace_size
,
scaling_factor
,
dropout_probability
,
bias_type
,
mask_type
,
qkv_layout
,
dtype
,
wkspace_dtype
,
is_training
,
deterministic
,
window_size_left
,
window_size_right
);
dropout_probability
,
bias_type
,
mask_type
,
softmax_type
,
qkv_layout
,
dtype
,
wkspace_dtype
,
is_training
,
deterministic
,
window_size_left
,
window_size_right
);
return
ffi_with_cuda_error_check
();
}
...
...
@@ -371,6 +394,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
.
Arg
<
Buffer_Type
>
()
// k
.
Arg
<
Buffer_Type
>
()
// v
.
Arg
<
Buffer_Type
>
()
// bias
.
Arg
<
Buffer_Type
>
()
// softmax_offset
.
Arg
<
Buffer_Type
>
()
// seed_buf
.
Arg
<
Buffer_Type
>
()
// q_cu_seqlens
.
Arg
<
Buffer_Type
>
()
// kv_cu_seqlens
...
...
@@ -388,9 +412,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_
QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
NVTE_Mask_Type
mask_type
,
NVTE_
Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
auto
q_shape
=
std
::
vector
<
size_t
>
{
input_batch
*
q_max_seqlen
,
attn_heads
,
qk_head_dim
};
auto
q_tensor
=
TensorWrapper
(
nullptr
,
q_shape
,
dtype
);
auto
dq_tensor
=
TensorWrapper
(
nullptr
,
q_shape
,
dtype
);
...
...
@@ -425,9 +449,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
// For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0
min_num_segments
=
input_batch
*
max_segments_per_seq
;
}
auto
dummy_d_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
TensorWrapper
dummy_d_softmax_offset_tensor
;
if
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_OFF_BY_ONE_SOFTMAX
||
softmax_type
==
NVTE_Softmax_Type
::
NVTE_LEARNABLE_SOFTMAX
)
{
dummy_d_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
,
attn_heads
,
1
,
1
},
DType
::
kFloat32
);
}
for
(
auto
num_segments
=
min_num_segments
;
num_segments
<=
max_num_segments
;
++
num_segments
)
{
// the last one is the largest which will be the returned workspace size
auto
q_cu_seqlens_tensor
=
...
...
@@ -457,15 +486,16 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
}
static
void
FusedAttnBackwardImpl
(
cudaStream_t
stream
,
void
*
q
,
void
*
k
,
void
*
v
,
void
*
bias
,
void
*
softmax_aux
,
void
*
rng_state
,
void
*
output
,
void
*
doutput
,
void
*
q_cu_seqlens
,
void
*
kv_cu_seqlens
,
void
*
q_seq_offsets
,
void
*
k_seq_offsets
,
void
*
dq
,
void
*
dk
,
void
*
dv
,
void
*
dbias
,
void
*
workspace
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
size_t
max_segments_per_seq
,
size_t
wkspace_size
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
DType
wkspace_dtype
,
bool
is_training
,
bool
deterministic
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
cudaStream_t
stream
,
void
*
q
,
void
*
k
,
void
*
v
,
void
*
bias
,
void
*
softmax_offset
,
void
*
softmax_aux
,
void
*
rng_state
,
void
*
output
,
void
*
doutput
,
void
*
q_cu_seqlens
,
void
*
kv_cu_seqlens
,
void
*
q_seq_offsets
,
void
*
k_seq_offsets
,
void
*
dq
,
void
*
dk
,
void
*
dv
,
void
*
dbias
,
void
*
dsoftmax_offset
,
void
*
workspace
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
size_t
max_segments_per_seq
,
size_t
wkspace_size
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
DType
wkspace_dtype
,
bool
is_training
,
bool
deterministic
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
FUSED_ATTN_IMPL_COMMON_BLOCK
;
/* Input tensors */
...
...
@@ -476,9 +506,13 @@ static void FusedAttnBackwardImpl(
/* Output tensors */
auto
s_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
dtype
);
// not used in F16
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
bias_shape
,
dtype
);
auto
dummy_d_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
TensorWrapper
dsoftmax_offset_tensor
;
if
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_OFF_BY_ONE_SOFTMAX
||
softmax_type
==
NVTE_Softmax_Type
::
NVTE_LEARNABLE_SOFTMAX
)
{
dsoftmax_offset_tensor
=
TensorWrapper
(
dsoftmax_offset
,
std
::
vector
<
size_t
>
{
1
,
attn_heads
,
1
,
1
},
DType
::
kFloat32
);
}
/* Auxiliary tensors (propagated from the forward pass) */
NVTETensorPack
aux_input_tensors
;
...
...
@@ -490,7 +524,7 @@ static void FusedAttnBackwardImpl(
false
,
false
);
PrepareFusedAttnBackwardAuxTensors
(
&
aux_input_tensors
,
input_batch
,
bias_batch
,
attn_heads
,
bias_heads
,
q_max_seqlen
,
kv_max_seqlen
,
dtype
,
backend
,
softmax_aux
,
rng_state
,
bias
);
softmax_aux
,
rng_state
,
bias
,
softmax_offset
);
/* Call the underly NVTE API */
// Prepare Q, K, V pointers and shapes based on layout
...
...
@@ -564,7 +598,7 @@ static void FusedAttnBackwardImpl(
s_tensor
.
data
(),
// not used for F16
s_tensor
.
data
(),
// not used for F16
&
aux_input_tensors
,
dq_tensor
.
data
(),
dk_tensor
.
data
(),
dv_tensor
.
data
(),
dbias_tensor
.
data
(),
d
ummy_d_
softmax_offset_tensor
.
data
(),
q_cu_seqlens_tensor
.
data
(),
kv_cu_seqlens_tensor
.
data
(),
dsoftmax_offset_tensor
.
data
(),
q_cu_seqlens_tensor
.
data
(),
kv_cu_seqlens_tensor
.
data
(),
q_seq_offsets_tensor
.
data
(),
k_seq_offsets_tensor
.
data
(),
q_max_seqlen
,
kv_max_seqlen
,
scaling_factor
,
dropout_probability
,
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
deterministic
,
false
,
workspace_tensor
.
data
(),
stream
);
...
...
@@ -574,26 +608,29 @@ static void FusedAttnBackwardImpl(
Error_Type
FusedAttnBackwardFFI
(
cudaStream_t
stream
,
Buffer_Type
q_buf
,
Buffer_Type
k_buf
,
Buffer_Type
v_buf
,
Buffer_Type
bias_buf
,
Buffer_Type
softmax_aux_buf
,
Buffer_Type
rng_state_buf
,
Buffer_Type
output_buf
,
Buffer_Type
doutput_buf
,
Buffer_Type
q_cu_seqlens_buf
,
Buffer_Type
kv_cu_seqlens_buf
,
Buffer_Type
q_seq_offsets_buf
,
Buffer_Type
k_seq_offsets_buf
,
Variadic_Buffer_Type
_unused_args
,
Result_Type
dq_buf
,
Result_Type
dk_buf
,
Result_Type
dv_buf
,
Result_Type
dbias_buf
,
Buffer_Type
softmax_offset_buf
,
Buffer_Type
softmax_aux_buf
,
Buffer_Type
rng_state_buf
,
Buffer_Type
output_buf
,
Buffer_Type
doutput_buf
,
Buffer_Type
q_cu_seqlens_buf
,
Buffer_Type
kv_cu_seqlens_buf
,
Buffer_Type
q_seq_offsets_buf
,
Buffer_Type
k_seq_offsets_buf
,
Variadic_Buffer_Type
_unused_args
,
Result_Type
dq_buf
,
Result_Type
dk_buf
,
Result_Type
dv_buf
,
Result_Type
dbias_buf
,
Result_Type
dsoftmax_offset_buf
,
Result_Type
workspace_buf
,
Dictionary
attrs
)
{
FUSED_ATTN_FFI_GET_ATTRS
;
FusedAttnBackwardImpl
(
stream
,
q_buf
.
untyped_data
(),
k_buf
.
untyped_data
(),
v_buf
.
untyped_data
(),
bias_buf
.
untyped_data
(),
softmax_aux_buf
.
untyped_data
(),
rng_state_buf
.
untyped_data
(),
output_buf
.
untyped_data
(),
doutput_buf
.
untyped_data
(),
q_cu_seqlens_buf
.
untyped_data
(),
kv_cu_seqlens_buf
.
untyped_data
(),
is_ragged
?
q_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
bias_buf
.
untyped_data
(),
softmax_offset_buf
.
untyped_data
(),
softmax_aux_buf
.
untyped_data
(),
rng_state_buf
.
untyped_data
(),
output_buf
.
untyped_data
(),
doutput_buf
.
untyped_data
(),
q_cu_seqlens_buf
.
untyped_data
(),
kv_cu_seqlens_buf
.
untyped_data
(),
is_ragged
?
q_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
is_ragged
?
k_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
dq_buf
->
untyped_data
(),
dk_buf
->
untyped_data
(),
dv_buf
->
untyped_data
(),
dbias_buf
->
untyped_data
(),
workspace_buf
->
untyped_data
(),
input_batch
,
bias_batch
,
q_max_seqlen
,
kv_max_seqlen
,
attn_heads
,
num_gqa_groups
,
bias_heads
,
qk_head_dim
,
v_head_dim
,
max_segments_per_seq
,
wkspace_size
,
scaling_factor
,
dropout_probability
,
bias_type
,
mask_type
,
qkv_layout
,
dtype
,
wkspace_dtype
,
is_training
,
deterministic
,
window_size_left
,
window_size_right
);
dsoftmax_offset_buf
->
untyped_data
(),
workspace_buf
->
untyped_data
(),
input_batch
,
bias_batch
,
q_max_seqlen
,
kv_max_seqlen
,
attn_heads
,
num_gqa_groups
,
bias_heads
,
qk_head_dim
,
v_head_dim
,
max_segments_per_seq
,
wkspace_size
,
scaling_factor
,
dropout_probability
,
bias_type
,
mask_type
,
softmax_type
,
qkv_layout
,
dtype
,
wkspace_dtype
,
is_training
,
deterministic
,
window_size_left
,
window_size_right
);
return
ffi_with_cuda_error_check
();
}
...
...
@@ -605,6 +642,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
.
Arg
<
Buffer_Type
>
()
// k
.
Arg
<
Buffer_Type
>
()
// v
.
Arg
<
Buffer_Type
>
()
// bias
.
Arg
<
Buffer_Type
>
()
// softmax_offset
.
Arg
<
Buffer_Type
>
()
// softmax_aux
.
Arg
<
Buffer_Type
>
()
// rng_state
.
Arg
<
Buffer_Type
>
()
// output
...
...
@@ -618,6 +656,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
.
Ret
<
Buffer_Type
>
()
// dk
.
Ret
<
Buffer_Type
>
()
// dv
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// dsoftmax_offset
.
Ret
<
Buffer_Type
>
()
// workspace
.
Attrs
(),
FFI_CudaGraph_Traits
);
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment