Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
970620a5
Commit
970620a5
authored
Dec 27, 2025
by
wenjh
Browse files
merge nv_release_v2.10 to release_v2.10
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
c1a1c04e
769ed778
Changes
135
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
648 additions
and
240 deletions
+648
-240
tests/pytorch/distributed/test_comm_gemm_overlap.py
tests/pytorch/distributed/test_comm_gemm_overlap.py
+6
-0
tests/pytorch/distributed/test_sanity.py
tests/pytorch/distributed/test_sanity.py
+73
-13
tests/pytorch/test_fusible_ops.py
tests/pytorch/test_fusible_ops.py
+5
-5
tests/pytorch/test_numerics.py
tests/pytorch/test_numerics.py
+0
-7
tests/pytorch/test_sanity.py
tests/pytorch/test_sanity.py
+1
-3
tests/pytorch/utils.py
tests/pytorch/utils.py
+10
-0
transformer_engine/common/fused_attn/kv_cache.cu
transformer_engine/common/fused_attn/kv_cache.cu
+1
-1
transformer_engine/common/include/transformer_engine/fused_attn.h
...mer_engine/common/include/transformer_engine/fused_attn.h
+1
-1
transformer_engine/common/recipe/__init__.py
transformer_engine/common/recipe/__init__.py
+20
-18
transformer_engine/common/transformer_engine.cpp
transformer_engine/common/transformer_engine.cpp
+12
-7
transformer_engine/jax/attention.py
transformer_engine/jax/attention.py
+75
-18
transformer_engine/jax/cpp_extensions/activation.py
transformer_engine/jax/cpp_extensions/activation.py
+1
-1
transformer_engine/jax/cpp_extensions/attention.py
transformer_engine/jax/cpp_extensions/attention.py
+236
-56
transformer_engine/jax/cpp_extensions/gemm.py
transformer_engine/jax/cpp_extensions/gemm.py
+1
-1
transformer_engine/jax/cpp_extensions/misc.py
transformer_engine/jax/cpp_extensions/misc.py
+1
-1
transformer_engine/jax/cpp_extensions/normalization.py
transformer_engine/jax/cpp_extensions/normalization.py
+1
-1
transformer_engine/jax/cpp_extensions/quantization.py
transformer_engine/jax/cpp_extensions/quantization.py
+19
-17
transformer_engine/jax/cpp_extensions/softmax.py
transformer_engine/jax/cpp_extensions/softmax.py
+65
-9
transformer_engine/jax/csrc/extensions.h
transformer_engine/jax/csrc/extensions.h
+12
-12
transformer_engine/jax/csrc/extensions/attention.cpp
transformer_engine/jax/csrc/extensions/attention.cpp
+108
-69
No files found.
tests/pytorch/distributed/test_comm_gemm_overlap.py
View file @
970620a5
...
@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
...
@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
os
.
environ
[
"PYTORCH_JIT"
]
=
"0"
os
.
environ
[
"PYTORCH_JIT"
]
=
"0"
os
.
environ
[
"NVTE_TORCH_COMPILE"
]
=
"0"
os
.
environ
[
"NVTE_TORCH_COMPILE"
]
=
"0"
os
.
environ
[
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
]
=
"0"
os
.
environ
[
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
]
=
"0"
if
te
.
get_device_compute_capability
()
<=
(
8
,
0
):
# We've experienced numerical discrepancies in Flash Attention
# backward when running with Userbuffers on A100s. This does
# not show up in more recent GPUs.
os
.
environ
[
"NVTE_FLASH_ATTN"
]
=
"0"
result
=
subprocess
.
run
(
test_cmd
,
env
=
os
.
environ
,
capture_output
=
True
,
check
=
False
)
result
=
subprocess
.
run
(
test_cmd
,
env
=
os
.
environ
,
capture_output
=
True
,
check
=
False
)
os
.
unsetenv
(
"PYTORCH_JIT"
)
os
.
unsetenv
(
"PYTORCH_JIT"
)
os
.
unsetenv
(
"NVTE_TORCH_COMPILE"
)
os
.
unsetenv
(
"NVTE_TORCH_COMPILE"
)
os
.
unsetenv
(
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
)
os
.
unsetenv
(
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
)
os
.
unsetenv
(
"NVTE_FLASH_ATTN"
)
if
(
if
(
result
.
returncode
!=
0
result
.
returncode
!=
0
...
...
tests/pytorch/distributed/test_sanity.py
View file @
970620a5
...
@@ -7,7 +7,7 @@ import sys
...
@@ -7,7 +7,7 @@ import sys
import
pytest
import
pytest
import
torch
import
torch
import
transformer_engine
import
transformer_engine
from
transformer_engine.pytorch
import
DotProductAttention
,
TransformerLayer
,
Linear
from
transformer_engine.pytorch
import
DotProductAttention
,
TransformerLayer
,
Linear
,
GroupedLinear
_current_file
=
pathlib
.
Path
(
__file__
).
resolve
()
_current_file
=
pathlib
.
Path
(
__file__
).
resolve
()
sys
.
path
.
append
(
str
(
_current_file
.
parent
.
parent
))
sys
.
path
.
append
(
str
(
_current_file
.
parent
.
parent
))
...
@@ -19,7 +19,9 @@ model_configs = {
...
@@ -19,7 +19,9 @@ model_configs = {
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"small"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"small"
])
@
pytest
.
mark
.
parametrize
(
"module"
,
[
"TransformerLayer"
,
"DotProductAttention"
,
"Linear"
])
@
pytest
.
mark
.
parametrize
(
"module"
,
[
"TransformerLayer"
,
"DotProductAttention"
,
"Linear"
,
"GroupedLinear"
]
)
def
test_current_device
(
model
,
module
):
def
test_current_device
(
model
,
module
):
"""Test cases where current device is different from tensor device"""
"""Test cases where current device is different from tensor device"""
...
@@ -42,7 +44,29 @@ def test_current_device(model, module):
...
@@ -42,7 +44,29 @@ def test_current_device(model, module):
self_attn_mask_type
=
"padding"
,
self_attn_mask_type
=
"padding"
,
device
=
f
"cuda:
{
tensor_device
}
"
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
)
num_tokens
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
,
(
1
,)).
item
()
seqlens_q
=
torch
.
randint
(
1
,
config
.
max_seqlen_q
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_q
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
seqlens_q
,
dim
=
0
)
seqlens_kv
=
torch
.
randint
(
1
,
config
.
max_seqlen_kv
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_kv
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_kv
[
1
:]
=
torch
.
cumsum
(
seqlens_kv
,
dim
=
0
)
num_tokens
=
cu_seqlens_q
[
-
1
]
args
=
[
args
=
[
torch
.
randn
(
torch
.
randn
(
(
num_tokens
,
config
.
hidden_size
),
(
num_tokens
,
config
.
hidden_size
),
...
@@ -51,37 +75,55 @@ def test_current_device(model, module):
...
@@ -51,37 +75,55 @@ def test_current_device(model, module):
requires_grad
=
True
,
requires_grad
=
True
,
)
)
]
]
cu_seqlens_q
,
cu_seqlens_kv
=
[
torch
.
Tensor
([
0
,
2
,
3
]).
to
(
dtype
=
torch
.
int32
,
device
=
tensor_device
)
for
_
in
range
(
2
)
]
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
if
module
==
"DotProductAttention"
:
el
if
module
==
"DotProductAttention"
:
model
=
DotProductAttention
(
model
=
DotProductAttention
(
config
.
num_heads
,
config
.
head_dim_qk
,
qkv_format
=
"thd"
,
attn_mask_type
=
"padding"
config
.
num_heads
,
config
.
head_dim_qk
,
qkv_format
=
"thd"
,
attn_mask_type
=
"padding"
)
)
num_tokens
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
,
(
1
,)).
item
()
seqlens_q
=
torch
.
randint
(
1
,
config
.
max_seqlen_q
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_q
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
seqlens_q
,
dim
=
0
)
seqlens_kv
=
torch
.
randint
(
1
,
config
.
max_seqlen_kv
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_kv
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_kv
[
1
:]
=
torch
.
cumsum
(
seqlens_kv
,
dim
=
0
)
num_tokens
=
cu_seqlens_q
[
-
1
]
args
=
[
args
=
[
torch
.
randn
(
torch
.
randn
(
num_tokens
,
num_tokens
,
config
.
num_heads
,
config
.
num_heads
,
config
.
head_dim_qk
,
config
.
head_dim_qk
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
tensor_device
,
device
=
f
"cuda:
{
tensor_device
}
"
,
requires_grad
=
True
,
requires_grad
=
True
,
)
)
for
_
in
range
(
3
)
for
_
in
range
(
3
)
]
]
cu_seqlens_q
,
cu_seqlens_kv
=
[
torch
.
Tensor
([
0
,
2
,
3
]).
to
(
dtype
=
torch
.
int32
,
device
=
tensor_device
)
for
_
in
range
(
2
)
]
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
bwd_args
=
[
torch
.
randn
(
num_tokens
,
config
.
hidden_size
,
dtype
=
dtype
,
device
=
tensor_device
)]
bwd_args
=
[
torch
.
randn
(
num_tokens
,
config
.
hidden_size
,
dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
)
]
elif
module
==
"Linear"
:
elif
module
==
"Linear"
:
model
=
Linear
(
model
=
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
...
@@ -97,6 +139,24 @@ def test_current_device(model, module):
...
@@ -97,6 +139,24 @@ def test_current_device(model, module):
requires_grad
=
True
,
requires_grad
=
True
,
)
)
]
]
elif
module
==
"GroupedLinear"
:
num_gemms
=
4
model
=
GroupedLinear
(
num_gemms
,
config
.
hidden_size
,
4
*
config
.
hidden_size
,
params_dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
args
=
[
torch
.
randn
(
(
config
.
max_seqlen_q
*
config
.
batch_size
*
(
num_gemms
-
1
),
config
.
hidden_size
),
dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
,
requires_grad
=
True
,
),
[
0
]
+
[
config
.
max_seqlen_q
*
config
.
batch_size
]
*
(
num_gemms
-
1
),
# Empty first split.
]
current_device_before
=
torch
.
cuda
.
current_device
()
current_device_before
=
torch
.
cuda
.
current_device
()
out
=
model
(
*
args
,
**
kwargs
)
out
=
model
(
*
args
,
**
kwargs
)
...
...
tests/pytorch/test_fusible_ops.py
View file @
970620a5
...
@@ -913,15 +913,15 @@ class TestBasicOps:
...
@@ -913,15 +913,15 @@ class TestBasicOps:
dtype
=
dtype
,
dtype
=
dtype
,
accumulate_into_main_grad
=
accumulate_into_main_grad
,
accumulate_into_main_grad
=
accumulate_into_main_grad
,
)
)
forward
=
te_ops
.
Sequential
(
te_ops
.
Quantize
(
forward
=
quantized_input
,
backward
=
quantized_grad_input
),
op
,
te_ops
.
Quantize
(
forward
=
quantized_output
,
backward
=
quantized_grad_output
),
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
op
.
weight
.
copy_
(
w_test
)
op
.
weight
.
copy_
(
w_test
)
del
w_test
del
w_test
op
.
weight
.
main_grad
=
torch
.
full_like
(
op
.
weight
,
0.5
,
dtype
=
torch
.
float32
)
op
.
weight
.
main_grad
=
torch
.
full_like
(
op
.
weight
,
0.5
,
dtype
=
torch
.
float32
)
forward
=
te_ops
.
Sequential
(
te_ops
.
Quantize
(
forward
=
quantized_input
,
backward
=
quantized_grad_input
),
op
,
te_ops
.
Quantize
(
forward
=
quantized_output
,
backward
=
quantized_grad_output
),
)
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
forward
(
x_test
)
y_test
=
forward
(
x_test
)
y_test
.
backward
(
dy_test
)
y_test
.
backward
(
dy_test
)
...
...
tests/pytorch/test_numerics.py
View file @
970620a5
...
@@ -46,7 +46,6 @@ from transformer_engine.pytorch import (
...
@@ -46,7 +46,6 @@ from transformer_engine.pytorch import (
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.module.base
import
get_multi_stream_cublas_workspace
,
get_workspace
from
transformer_engine.common
import
recipe
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
utils
import
ModelConfig
,
reset_rng_states
from
utils
import
ModelConfig
,
reset_rng_states
...
@@ -2757,7 +2756,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
...
@@ -2757,7 +2756,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
general_gemm
(
general_gemm
(
A
[
i
],
A
[
i
],
B
[
i
],
B
[
i
],
get_workspace
(),
dtype
,
dtype
,
grad
=
grad
,
grad
=
grad
,
accumulate
=
accumulate
,
accumulate
=
accumulate
,
...
@@ -2772,7 +2770,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
...
@@ -2772,7 +2770,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
B
,
B
,
out
,
out
,
dtype
,
dtype
,
get_multi_stream_cublas_workspace
(),
m_splits
=
m_splits
,
m_splits
=
m_splits
,
grad
=
grad
,
grad
=
grad
,
accumulate
=
accumulate
,
accumulate
=
accumulate
,
...
@@ -2832,7 +2829,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
...
@@ -2832,7 +2829,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
quantized_out
,
*
_
=
general_gemm
(
quantized_out
,
*
_
=
general_gemm
(
weight_fp8
,
weight_fp8
,
inp_fp8
,
inp_fp8
,
get_workspace
(),
outp_type
,
outp_type
,
quantization_params
=
out_quantizer
,
quantization_params
=
out_quantizer
,
bias
=
None
,
bias
=
None
,
...
@@ -2842,7 +2838,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
...
@@ -2842,7 +2838,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
out
,
*
_
=
general_gemm
(
out
,
*
_
=
general_gemm
(
weight_fp8
,
weight_fp8
,
inp_fp8
,
inp_fp8
,
get_workspace
(),
outp_type
,
outp_type
,
quantization_params
=
None
,
quantization_params
=
None
,
bias
=
None
,
bias
=
None
,
...
@@ -2918,7 +2913,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
...
@@ -2918,7 +2913,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
general_gemm
(
general_gemm
(
A_fp8
[
i
],
A_fp8
[
i
],
B_fp8
[
i
],
B_fp8
[
i
],
get_workspace
(),
dtype
,
dtype
,
out
=
out_ref
[
i
],
out
=
out_ref
[
i
],
accumulate
=
accumulate
,
accumulate
=
accumulate
,
...
@@ -2928,7 +2922,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
...
@@ -2928,7 +2922,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
B_fp8
,
B_fp8
,
out
,
out
,
dtype
,
dtype
,
get_multi_stream_cublas_workspace
(),
m_splits
=
m_splits
,
m_splits
=
m_splits
,
accumulate
=
accumulate
,
accumulate
=
accumulate
,
)
)
...
...
tests/pytorch/test_sanity.py
View file @
970620a5
...
@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
...
@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
from
transformer_engine.common
import
recipe
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
from
transformer_engine.pytorch.module.base
import
get_workspace
from
transformer_engine.pytorch.tensor.utils
import
replace_raw_data
from
transformer_engine.pytorch.tensor.utils
import
replace_raw_data
from
utils
import
ModelConfig
from
utils
import
ModelConfig
...
@@ -961,7 +960,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
...
@@ -961,7 +960,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
inp
=
torch
.
reshape
(
scratchpad
[
offset
:
-
offset
],
(
N
,
N
))
inp
=
torch
.
reshape
(
scratchpad
[
offset
:
-
offset
],
(
N
,
N
))
weight
=
torch
.
reshape
(
scratchpad
[
offset
*
2
:],
(
N
,
N
))
weight
=
torch
.
reshape
(
scratchpad
[
offset
*
2
:],
(
N
,
N
))
_
=
general_gemm
(
A
=
weight
,
B
=
inp
,
workspace
=
get_workspace
()
)
_
=
general_gemm
(
A
=
weight
,
B
=
inp
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -985,7 +984,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
...
@@ -985,7 +984,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
general_gemm
(
general_gemm
(
weight_fp8
,
weight_fp8
,
inp_fp8
,
inp_fp8
,
get_workspace
(),
outp_type
,
outp_type
,
bias
=
None
,
bias
=
None
,
use_split_accumulator
=
False
,
use_split_accumulator
=
False
,
...
...
tests/pytorch/utils.py
View file @
970620a5
...
@@ -8,6 +8,7 @@ import logging
...
@@ -8,6 +8,7 @@ import logging
import
os
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Optional
,
Tuple
,
Dict
,
Any
,
List
from
typing
import
Optional
,
Tuple
,
Dict
,
Any
,
List
from
packaging.version
import
Version
as
PkgVersion
import
torch
import
torch
...
@@ -210,6 +211,7 @@ class ModelConfig:
...
@@ -210,6 +211,7 @@ class ModelConfig:
max_ctx_len
:
int
=
None
,
max_ctx_len
:
int
=
None
,
num_layers
:
int
=
1
,
num_layers
:
int
=
1
,
eps
:
float
=
1e-5
,
eps
:
float
=
1e-5
,
num_splits
=
1
,
):
):
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
max_seqlen_q
=
max_seqlen_q
self
.
max_seqlen_q
=
max_seqlen_q
...
@@ -239,6 +241,7 @@ class ModelConfig:
...
@@ -239,6 +241,7 @@ class ModelConfig:
self
.
max_ctx_len
=
max_ctx_len
self
.
max_ctx_len
=
max_ctx_len
self
.
num_layers
=
num_layers
self
.
num_layers
=
num_layers
self
.
eps
=
eps
self
.
eps
=
eps
self
.
num_splits
=
num_splits
@
contextmanager
@
contextmanager
...
@@ -321,6 +324,9 @@ def get_available_attention_backends(
...
@@ -321,6 +324,9 @@ def get_available_attention_backends(
inference_params
=
inference_params
,
inference_params
=
inference_params
,
softmax_type
=
config
.
softmax_type
,
softmax_type
=
config
.
softmax_type
,
return_max_logit
=
config
.
return_max_logit
,
return_max_logit
=
config
.
return_max_logit
,
# allow all backends to pass so they can be used for testing;
# check for FA3 availability later
num_splits
=
1
,
)
)
(
(
use_flash_attention
,
use_flash_attention
,
...
@@ -330,6 +336,10 @@ def get_available_attention_backends(
...
@@ -330,6 +336,10 @@ def get_available_attention_backends(
use_unfused_attention
,
use_unfused_attention
,
available_backends
,
available_backends
,
)
=
get_attention_backend
(
attention_params
)
)
=
get_attention_backend
(
attention_params
)
# Check if FA3 is an available backend when num_splits != 1
if
available_backends
[
0
]:
if
config
.
num_splits
!=
1
and
not
flash_attention_backend
>
PkgVersion
(
"3.0.0b"
):
available_backends
[
0
]
=
False
# Set attention.py _attention_backends var using return value
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
# from get_attention_backend()
_attention_backends
[
"use_flash_attention"
]
=
use_flash_attention
_attention_backends
[
"use_flash_attention"
]
=
use_flash_attention
...
...
transformer_engine/common/fused_attn/kv_cache.cu
View file @
970620a5
...
@@ -278,7 +278,7 @@ void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, in
...
@@ -278,7 +278,7 @@ void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, in
/***************************************************************************************************
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are
in
shape [b + 1]; cu_cached_lens include the added lens
* 2. cu_new_lens and cu_cached_lens are
of
shape [b + 1]; cu_cached_lens include the added lens
* in current step
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
...
...
transformer_engine/common/include/transformer_engine/fused_attn.h
View file @
970620a5
...
@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
...
@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter
in
shape [H].
* where alpha is a learnable parameter
of
shape [H].
*/
*/
enum
NVTE_Softmax_Type
{
enum
NVTE_Softmax_Type
{
/*! Vanilla softmax */
/*! Vanilla softmax */
...
...
transformer_engine/common/recipe/__init__.py
View file @
970620a5
...
@@ -50,7 +50,7 @@ class MMParams:
...
@@ -50,7 +50,7 @@ class MMParams:
Parameters
Parameters
----------
----------
use_split_accumulator : bool, default =
`
True
`
use_split_accumulator : bool, default = True
Use FP8 fast accumulation on Hopper or Ada. For more details,
Use FP8 fast accumulation on Hopper or Ada. For more details,
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
"""
"""
...
@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
...
@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
recipe: DelayedScaling) -> Tensor
recipe: DelayedScaling) -> Tensor
where `Tensor` is a framework tensor type.
where `Tensor` is a framework tensor type.
reduce_amax: bool, default =
`
True
`
reduce_amax: bool, default = True
By default, if `torch.distributed` is initialized, the `amax` value for FP8
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
call). This keeps the amaxes and scaling factors synced across the given
call). This keeps the amaxes and scaling factors synced across the given
...
@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
...
@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
GPU maintains local amaxes and scaling factors. To ensure results are
GPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all
numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors.
ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default =
`
False
`
fp8_dpa: bool, default = False
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
`FusedAttention` backend.
fp8_mha: bool, default =
`
False
`
fp8_mha: bool, default = False
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
...
@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
...
@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
----------
----------
fp4_format : {Format.E2M1}, default = Format.E2M1
fp4_format : {Format.E2M1}, default = Format.E2M1
FP4 data type.
FP4 data type.
disable_rht : bool, default =
`
False
`
disable_rht : bool, default = False
If set to `True`, random Hadamard transforms are not applied to any tensor.
If set to `True`, random Hadamard transforms are not applied to any tensor.
disable_stochastic_rounding : bool, default =
`
False
`
disable_stochastic_rounding : bool, default = False
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default =
`
False
`
disable_2d_quantization : bool, default = False
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
"""
"""
...
@@ -492,17 +492,19 @@ class CustomRecipe(Recipe):
...
@@ -492,17 +492,19 @@ class CustomRecipe(Recipe):
Parameters
Parameters
----------
----------
qfactory : Callable
qfactory : Callable
Factory callable that returns a quantizer instance for a
Factory callable that returns a quantizer instance for a
given semantic tensor role.
given semantic tensor role.
The callable is typically invoked as:
The callable is typically invoked as::
qfactory(
role: str,
qfactory(
)
role: str,
)
Where `role` is one of the following strings for e.g. te.Linear
(stable public contract):
Where `role` is one of the following strings for e.g. te.Linear
- forward: "linear_input", "linear_weight", "linear_output"
(stable public contract):
- backward: "linear_grad_output", "linear_grad_input"
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
"""
"""
qfactory
:
Callable
[...,
Any
]
qfactory
:
Callable
[...,
Any
]
...
...
transformer_engine/common/transformer_engine.cpp
View file @
970620a5
...
@@ -736,12 +736,17 @@ int nvte_is_non_tn_fp8_gemm_supported() {
...
@@ -736,12 +736,17 @@ int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
#if USE_ROCM
return
true
;
return
true
;
#else
#else
int
deviceComputeCapability
=
int
num_devices
=
transformer_engine
::
cuda
::
num_devices
();
transformer_engine
::
cuda
::
sm_arch
(
transformer_engine
::
cuda
::
current_device
());
static
std
::
vector
<
int
>
cache
(
num_devices
,
-
1
);
static
std
::
vector
<
std
::
once_flag
>
flags
(
num_devices
);
// Note: this is temporary restriction and should be lifted in the future.
int
device_id
=
transformer_engine
::
cuda
::
current_device
();
// (remove the note once it's done.)
std
::
call_once
(
flags
[
device_id
],
[
&
]()
{
return
(
deviceComputeCapability
>=
100
&&
deviceComputeCapability
<
120
)
||
int
deviceComputeCapability
=
transformer_engine
::
cuda
::
sm_arch
(
device_id
);
deviceComputeCapability
>=
130
;
// Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
cache
[
device_id
]
=
(
deviceComputeCapability
>=
100
&&
deviceComputeCapability
<
120
)
||
deviceComputeCapability
>=
130
;
});
return
cache
[
device_id
];
#endif
#endif
}
}
transformer_engine/jax/attention.py
View file @
970620a5
...
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
...
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
from
transformer_engine_jax
import
NVTE_QKV_Layout
from
transformer_engine_jax
import
NVTE_QKV_Layout
from
transformer_engine_jax
import
NVTE_QKV_Format
from
transformer_engine_jax
import
NVTE_QKV_Format
from
transformer_engine_jax
import
nvte_get_qkv_format
from
transformer_engine_jax
import
nvte_get_qkv_format
from
transformer_engine_jax
import
NVTE_Softmax_Type
from
.
import
cpp_extensions
as
tex
from
.
import
cpp_extensions
as
tex
...
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
...
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
]
]
class
AttnSoftmaxType
(
Enum
):
"""
VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [H].
"""
VANILLA_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_VANILLA_SOFTMAX
OFF_BY_ONE_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_OFF_BY_ONE_SOFTMAX
LEARNABLE_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_LEARNABLE_SOFTMAX
@
classmethod
def
from_str
(
cls
,
softmax_type
:
str
)
->
"AttnSoftmaxType"
:
"""Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
softmax_type_map
=
{
"vanilla"
:
cls
.
VANILLA_SOFTMAX
,
"off_by_one"
:
cls
.
OFF_BY_ONE_SOFTMAX
,
"learnable"
:
cls
.
LEARNABLE_SOFTMAX
,
}
result
=
softmax_type_map
.
get
(
softmax_type
)
if
result
is
None
:
raise
ValueError
(
f
"Unknown softmax_type:
{
softmax_type
}
. "
"Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
return
result
class
QKVFormat
(
Enum
):
class
QKVFormat
(
Enum
):
"""
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
SBHD: q,k,v memory layout with [s, b, ..., h, d]
...
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
...
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
qkv_layout
,
qkv_layout
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
softmax_type
,
dropout_probability
,
dropout_probability
,
q_num_heads
,
q_num_heads
,
kv_num_heads
,
kv_num_heads
,
...
@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
...
@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
"""
"""
To check whether the fused attention kernel is supported
To check whether the fused attention kernel is supported
"""
"""
window_size_tuple
=
(
-
1
,
-
1
)
if
window_size
is
None
else
window_size
def
make_helper
(
attn_mask_type
):
def
make_helper
(
attn_mask_type
):
return
tex
.
FusedAttnHelper
(
return
tex
.
FusedAttnHelper
(
...
@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
...
@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
qkv_layout
,
qkv_layout
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
softmax_type
,
dropout_probability
,
dropout_probability
,
q_num_heads
,
q_num_heads
,
kv_num_heads
,
kv_num_heads
,
...
@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
...
@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
kv_max_seqlen
,
kv_max_seqlen
,
head_dim_qk
,
head_dim_qk
,
head_dim_v
,
head_dim_v
,
(
-
1
,
-
1
)
if
window_size
is
None
else
window_size
,
window_siz
e_tupl
e
,
)
)
return
make_helper
(
attn_mask_type
).
is_fused_attn_kernel_available
()
return
make_helper
(
attn_mask_type
).
is_fused_attn_kernel_available
()
...
@@ -497,6 +530,11 @@ def _segment_ids_pos_to_seqlens_offsets(
...
@@ -497,6 +530,11 @@ def _segment_ids_pos_to_seqlens_offsets(
#
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
# examine only O(Q+KV) elements.
# For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation
# using the segment ids and pos along with mask type (causal or brcm) is sufficient.
# It does not need to involve SW for this mask's creation
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if
(
attn_mask_type
.
is_causal
()
and
window_size
is
None
)
or
(
if
(
attn_mask_type
.
is_causal
()
and
window_size
is
None
)
or
(
window_size
==
(
-
1
,
-
1
)
and
not
attn_mask_type
.
is_bottom_right
()
window_size
==
(
-
1
,
-
1
)
and
not
attn_mask_type
.
is_bottom_right
()
...
@@ -558,21 +596,6 @@ def _segment_ids_pos_to_seqlens_offsets(
...
@@ -558,21 +596,6 @@ def _segment_ids_pos_to_seqlens_offsets(
)
)
attn_mask
=
jnp
.
logical_and
(
segment_mask
,
causal_mask
)
attn_mask
=
jnp
.
logical_and
(
segment_mask
,
causal_mask
)
# TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
swa_mask
=
(
make_swa_mask
(
segment_pos_q
,
segment_pos_kv
,
window_size
,
dtype
=
jnp
.
bool
,
segment_ids_q
=
segment_ids_q
,
segment_ids_kv
=
segment_ids_kv
,
)
if
attn_mask_type
.
is_bottom_right
()
else
make_swa_mask
(
segment_pos_q
,
segment_pos_kv
,
window_size
,
dtype
=
jnp
.
bool
)
)
attn_mask
=
jnp
.
logical_and
(
attn_mask
,
swa_mask
)
attn_mask_with_id
=
jnp
.
where
(
attn_mask
,
segment_mask_with_id
,
0
)
attn_mask_with_id
=
jnp
.
where
(
attn_mask
,
segment_mask_with_id
,
0
)
q_seqlen
,
q_offset
,
kv_seqlen
,
kv_offset
=
_mask_to_seqlens_offset
(
q_seqlen
,
q_offset
,
kv_seqlen
,
kv_offset
=
_mask_to_seqlens_offset
(
attn_mask_with_id
,
max_segments_per_seq
attn_mask_with_id
,
max_segments_per_seq
...
@@ -786,6 +809,7 @@ def _legacy_fused_attn(
...
@@ -786,6 +809,7 @@ def _legacy_fused_attn(
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -793,6 +817,7 @@ def _legacy_fused_attn(
...
@@ -793,6 +817,7 @@ def _legacy_fused_attn(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
):
"""
"""
Perform non-THD (non-packed) cuDNN fused attention.
Perform non-THD (non-packed) cuDNN fused attention.
...
@@ -815,6 +840,7 @@ def _legacy_fused_attn(
...
@@ -815,6 +840,7 @@ def _legacy_fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -863,10 +889,12 @@ def _legacy_fused_attn(
...
@@ -863,10 +889,12 @@ def _legacy_fused_attn(
output
=
_fused_attn
(
output
=
_fused_attn
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
SequenceDescriptor
.
from_seqlens
((
q_seq_lens
,
kv_seq_lens
)),
SequenceDescriptor
.
from_seqlens
((
q_seq_lens
,
kv_seq_lens
)),
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -900,6 +928,7 @@ def fused_attn_thd(
...
@@ -900,6 +928,7 @@ def fused_attn_thd(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
):
"""
"""
Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor
Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor
...
@@ -937,6 +966,7 @@ def fused_attn_thd(
...
@@ -937,6 +966,7 @@ def fused_attn_thd(
output
=
_fused_attn
(
output
=
_fused_attn
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
SequenceDescriptor
.
from_seqlens_and_offsets
(
SequenceDescriptor
.
from_seqlens_and_offsets
(
(
q_seq_lens
,
kv_seq_lens
),
(
q_seq_offsets
,
kv_seq_offsets
)
(
q_seq_lens
,
kv_seq_lens
),
(
q_seq_offsets
,
kv_seq_offsets
)
),
),
...
@@ -945,6 +975,7 @@ def fused_attn_thd(
...
@@ -945,6 +975,7 @@ def fused_attn_thd(
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
softmax_type
=
AttnSoftmaxType
.
VANILLA_SOFTMAX
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
max_segments_per_seq
=
max_segments_per_seq
,
max_segments_per_seq
=
max_segments_per_seq
,
...
@@ -957,15 +988,17 @@ def fused_attn_thd(
...
@@ -957,15 +988,17 @@ def fused_attn_thd(
return
output
return
output
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
))
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
))
def
_fused_attn
(
def
_fused_attn
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
sequence_descriptor
:
SequenceDescriptor
,
sequence_descriptor
:
SequenceDescriptor
,
seed
:
Optional
[
jnp
.
ndarray
],
seed
:
Optional
[
jnp
.
ndarray
],
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -979,11 +1012,13 @@ def _fused_attn(
...
@@ -979,11 +1012,13 @@ def _fused_attn(
output
,
_
=
_fused_attn_fwd_rule
(
output
,
_
=
_fused_attn_fwd_rule
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
qkv_layout
,
qkv_layout
,
softmax_type
,
scaling_factor
,
scaling_factor
,
dropout_probability
,
dropout_probability
,
is_training
,
is_training
,
...
@@ -1000,11 +1035,13 @@ def _fused_attn(
...
@@ -1000,11 +1035,13 @@ def _fused_attn(
def
_fused_attn_fwd_rule
(
def
_fused_attn_fwd_rule
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
qkv_layout
,
qkv_layout
,
softmax_type
,
scaling_factor
,
scaling_factor
,
dropout_probability
,
dropout_probability
,
is_training
,
is_training
,
...
@@ -1018,10 +1055,12 @@ def _fused_attn_fwd_rule(
...
@@ -1018,10 +1055,12 @@ def _fused_attn_fwd_rule(
output
,
softmax_aux
,
rng_state
=
tex
.
fused_attn_fwd
(
output
,
softmax_aux
,
rng_state
=
tex
.
fused_attn_fwd
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -1041,6 +1080,7 @@ def _fused_attn_fwd_rule(
...
@@ -1041,6 +1080,7 @@ def _fused_attn_fwd_rule(
sequence_descriptor
,
sequence_descriptor
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
softmax_offset
,
output
,
output
,
)
)
...
@@ -1049,6 +1089,7 @@ def _fused_attn_bwd_rule(
...
@@ -1049,6 +1089,7 @@ def _fused_attn_bwd_rule(
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
qkv_layout
,
qkv_layout
,
softmax_type
,
scaling_factor
,
scaling_factor
,
dropout_probability
,
dropout_probability
,
is_training
,
is_training
,
...
@@ -1068,11 +1109,13 @@ def _fused_attn_bwd_rule(
...
@@ -1068,11 +1109,13 @@ def _fused_attn_bwd_rule(
sequence_descriptor
,
sequence_descriptor
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
softmax_offset
,
output
,
output
,
)
=
ctx
)
=
ctx
grad_qkv
,
grad_bias
=
tex
.
fused_attn_bwd
(
grad_qkv
,
grad_bias
,
grad_softmax_offset
=
tex
.
fused_attn_bwd
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1080,6 +1123,7 @@ def _fused_attn_bwd_rule(
...
@@ -1080,6 +1123,7 @@ def _fused_attn_bwd_rule(
sequence_descriptor
,
sequence_descriptor
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -1092,9 +1136,12 @@ def _fused_attn_bwd_rule(
...
@@ -1092,9 +1136,12 @@ def _fused_attn_bwd_rule(
)
)
if
attn_bias_type
==
AttnBiasType
.
NO_BIAS
:
if
attn_bias_type
==
AttnBiasType
.
NO_BIAS
:
grad_bias
=
None
grad_bias
=
None
if
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
grad_softmax_offset
=
None
return
(
return
(
grad_qkv
,
grad_qkv
,
grad_bias
,
grad_bias
,
grad_softmax_offset
,
None
,
None
,
None
,
None
,
)
)
...
@@ -1111,6 +1158,7 @@ def fused_attn(
...
@@ -1111,6 +1158,7 @@ def fused_attn(
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -1120,6 +1168,7 @@ def fused_attn(
...
@@ -1120,6 +1168,7 @@ def fused_attn(
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
context_checkpoint_name
:
str
=
"context"
,
context_checkpoint_name
:
str
=
"context"
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
):
"""
"""
Perform cuDNN fused attention.
Perform cuDNN fused attention.
...
@@ -1139,6 +1188,7 @@ def fused_attn(
...
@@ -1139,6 +1188,7 @@ def fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -1153,6 +1203,9 @@ def fused_attn(
...
@@ -1153,6 +1203,9 @@ def fused_attn(
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
Returns:
Returns:
(jnp.ndarray): The output tensor from the fused attention.
(jnp.ndarray): The output tensor from the fused attention.
...
@@ -1200,6 +1253,7 @@ def fused_attn(
...
@@ -1200,6 +1253,7 @@ def fused_attn(
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -1208,15 +1262,18 @@ def fused_attn(
...
@@ -1208,15 +1262,18 @@ def fused_attn(
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_axis
=
context_parallel_axis
,
context_parallel_axis
=
context_parallel_axis
,
softmax_offset
=
softmax_offset
,
)
)
output
=
_fused_attn
(
output
=
_fused_attn
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
...
...
transformer_engine/jax/cpp_extensions/activation.py
View file @
970620a5
...
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
...
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
(
from
..quantize
import
(
Quantizer
,
Quantizer
,
QuantizeLayout
,
DelayedScaleQuantizer
,
DelayedScaleQuantizer
,
ScalingMode
,
ScalingMode
,
QuantizeLayout
,
)
)
...
...
transformer_engine/jax/cpp_extensions/attention.py
View file @
970620a5
...
@@ -20,11 +20,13 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
...
@@ -20,11 +20,13 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
from
transformer_engine.jax.attention
import
(
from
transformer_engine.jax.attention
import
(
AttnBiasType
,
AttnBiasType
,
AttnMaskType
,
AttnMaskType
,
AttnSoftmaxType
,
QKVLayout
,
QKVLayout
,
QKVFormat
,
QKVFormat
,
CPStrategy
,
CPStrategy
,
SequenceDescriptor
,
SequenceDescriptor
,
)
)
from
..sharding
import
with_sharding_constraint_by_logical_axes
,
HEAD_AXES
from
.base
import
BasePrimitive
,
register_primitive
from
.base
import
BasePrimitive
,
register_primitive
from
.misc
import
(
from
.misc
import
(
...
@@ -61,6 +63,7 @@ __all__ = [
...
@@ -61,6 +63,7 @@ __all__ = [
meta_fields
=
[
meta_fields
=
[
"attn_bias_type"
,
"attn_bias_type"
,
"attn_mask_type"
,
"attn_mask_type"
,
"softmax_type"
,
"qkv_layout"
,
"qkv_layout"
,
"scaling_factor"
,
"scaling_factor"
,
"dropout_probability"
,
"dropout_probability"
,
...
@@ -80,6 +83,7 @@ class _FusedAttnConfig:
...
@@ -80,6 +83,7 @@ class _FusedAttnConfig:
attn_bias_type
:
AttnBiasType
attn_bias_type
:
AttnBiasType
attn_mask_type
:
AttnMaskType
attn_mask_type
:
AttnMaskType
softmax_type
:
AttnSoftmaxType
qkv_layout
:
QKVLayout
qkv_layout
:
QKVLayout
scaling_factor
:
float
scaling_factor
:
float
dropout_probability
:
float
dropout_probability
:
float
...
@@ -103,6 +107,7 @@ class FusedAttnHelper:
...
@@ -103,6 +107,7 @@ class FusedAttnHelper:
qkv_layout
:
QKVLayout
qkv_layout
:
QKVLayout
attn_bias_type
:
AttnBiasType
attn_bias_type
:
AttnBiasType
attn_mask_type
:
AttnMaskType
attn_mask_type
:
AttnMaskType
softmax_type
:
AttnSoftmaxType
dropout_probability
:
float
dropout_probability
:
float
q_num_heads
:
int
q_num_heads
:
int
kv_num_heads
:
int
kv_num_heads
:
int
...
@@ -125,6 +130,7 @@ class FusedAttnHelper:
...
@@ -125,6 +130,7 @@ class FusedAttnHelper:
self
.
qkv_layout
.
value
,
self
.
qkv_layout
.
value
,
self
.
attn_bias_type
.
value
,
self
.
attn_bias_type
.
value
,
self
.
attn_mask_type
.
value
,
self
.
attn_mask_type
.
value
,
self
.
softmax_type
.
value
,
self
.
dropout_probability
,
self
.
dropout_probability
,
self
.
q_num_heads
,
self
.
q_num_heads
,
self
.
kv_num_heads
,
self
.
kv_num_heads
,
...
@@ -254,7 +260,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -254,7 +260,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
name
=
"te_fused_attn_forward_ffi"
name
=
"te_fused_attn_forward_ffi"
multiple_results
=
True
multiple_results
=
True
impl_static_args
=
(
1
3
,)
impl_static_args
=
(
1
4
,)
inner_primitive
=
None
inner_primitive
=
None
outer_primitive
=
None
outer_primitive
=
None
...
@@ -264,6 +270,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -264,6 +270,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k_aval
,
k_aval
,
v_aval
,
v_aval
,
bias_aval
,
bias_aval
,
softmax_offset_aval
,
seed_aval
,
seed_aval
,
q_seqlen_or_cu_seqlen_aval
,
q_seqlen_or_cu_seqlen_aval
,
kv_seqlen_or_cu_seqlen_aval
,
kv_seqlen_or_cu_seqlen_aval
,
...
@@ -312,6 +319,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -312,6 +319,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config
.
qkv_layout
,
config
.
qkv_layout
,
config
.
attn_bias_type
,
config
.
attn_bias_type
,
config
.
attn_mask_type
,
config
.
attn_mask_type
,
config
.
softmax_type
,
config
.
dropout_probability
,
config
.
dropout_probability
,
attn_heads
,
attn_heads
,
num_gqa_groups
,
num_gqa_groups
,
...
@@ -375,6 +383,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -375,6 +383,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config
.
dropout_probability
,
config
.
dropout_probability
,
config
.
attn_bias_type
.
value
,
config
.
attn_bias_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
softmax_type
.
value
,
config
.
qkv_layout
.
value
,
config
.
qkv_layout
.
value
,
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
config
.
is_training
,
config
.
is_training
,
...
@@ -386,6 +395,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -386,6 +395,12 @@ class FusedAttnFwdPrimitive(BasePrimitive):
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
)
)
assert
softmax_offset_aval
.
dtype
==
jnp
.
float32
if
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
assert
softmax_offset_aval
.
shape
==
(
1
,
attn_heads
,
1
,
1
)
else
:
assert
softmax_offset_aval
.
shape
==
(
0
,)
return
out_aval
,
softmax_aux_aval
,
rng_state_aval
,
wkspace_aval
return
out_aval
,
softmax_aux_aval
,
rng_state_aval
,
wkspace_aval
@
staticmethod
@
staticmethod
...
@@ -405,6 +420,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -405,6 +420,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_cu_seqlen
,
q_cu_seqlen
,
kv_cu_seqlen
,
kv_cu_seqlen
,
...
@@ -453,6 +469,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -453,6 +469,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_cu_seqlen
,
q_cu_seqlen
,
kv_cu_seqlen
,
kv_cu_seqlen
,
...
@@ -481,6 +498,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -481,6 +498,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
window_size_left
=
window_size_left
,
window_size_left
=
window_size_left
,
window_size_right
=
window_size_right
,
window_size_right
=
window_size_right
,
softmax_type
=
int
(
config
.
softmax_type
.
value
),
)
)
@
staticmethod
@
staticmethod
...
@@ -489,6 +507,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -489,6 +507,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -579,6 +598,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -579,6 +598,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_cu_seqlen
,
q_cu_seqlen
,
kv_cu_seqlen
,
kv_cu_seqlen
,
...
@@ -596,7 +616,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -596,7 +616,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
check_valid_batch_dims
(
batch_dims
)
check_valid_batch_dims
(
batch_dims
)
assert
FusedAttnFwdPrimitive
.
outer_primitive
is
not
None
assert
FusedAttnFwdPrimitive
.
outer_primitive
is
not
None
q_bdim
,
_
,
_
,
_
,
seed_bdim
,
*
_
=
batch_dims
q_bdim
,
_
,
_
,
_
,
_
,
seed_bdim
,
*
_
=
batch_dims
out_bdims
=
q_bdim
,
q_bdim
,
seed_bdim
out_bdims
=
q_bdim
,
q_bdim
,
seed_bdim
return
(
return
(
...
@@ -662,7 +682,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -662,7 +682,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
...
@@ -710,7 +730,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -710,7 +730,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
name
=
"te_fused_attn_backward_ffi"
name
=
"te_fused_attn_backward_ffi"
multiple_results
=
True
multiple_results
=
True
impl_static_args
=
(
1
6
,)
impl_static_args
=
(
1
7
,)
inner_primitive
=
None
inner_primitive
=
None
outer_primitive
=
None
outer_primitive
=
None
...
@@ -720,6 +740,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -720,6 +740,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_aval
,
k_aval
,
v_aval
,
v_aval
,
bias_aval
,
bias_aval
,
softmax_offset_aval
,
softmax_aux_aval
,
softmax_aux_aval
,
rng_state_aval
,
rng_state_aval
,
output_aval
,
output_aval
,
...
@@ -781,6 +802,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -781,6 +802,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config
.
dropout_probability
,
config
.
dropout_probability
,
config
.
attn_bias_type
.
value
,
config
.
attn_bias_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
attn_mask_type
.
value
,
config
.
softmax_type
.
value
,
config
.
qkv_layout
.
value
,
config
.
qkv_layout
.
value
,
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
jax_dtype_to_te_dtype
(
q_aval
.
dtype
),
config
.
is_training
,
config
.
is_training
,
...
@@ -798,15 +820,39 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -798,15 +820,39 @@ class FusedAttnBwdPrimitive(BasePrimitive):
shape
=
wkspace_shape
,
dtype
=
te_dtype_to_jax_dtype
(
wkspace_dtype
)
shape
=
wkspace_shape
,
dtype
=
te_dtype_to_jax_dtype
(
wkspace_dtype
)
)
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
wkspace_aval
# Validate incoming softmax_offset shape and dtype
assert
(
softmax_offset_aval
.
dtype
==
jnp
.
float32
),
f
"Incorrect softmax_offset dtype:
{
softmax_offset_aval
.
dtype
}
, expected:
{
jnp
.
float32
}
"
if
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
assert
softmax_offset_aval
.
shape
==
(
1
,
attn_heads
,
1
,
1
),
(
f
"Incorrect softmax_offset shape for
{
config
.
softmax_type
}
:"
f
"
{
softmax_offset_aval
.
shape
}
, expected: (1,
{
attn_heads
}
, 1, 1)"
)
else
:
assert
softmax_offset_aval
.
shape
==
(
0
,),
(
f
"Incorrect softmax_offset shape for
{
config
.
softmax_type
}
:"
f
"
{
softmax_offset_aval
.
shape
}
, expected: (0,)"
)
if
config
.
softmax_type
==
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
dsoftmax_offset_aval
=
q_aval
.
update
(
shape
=
softmax_offset_aval
.
shape
,
dtype
=
softmax_offset_aval
.
dtype
)
else
:
dsoftmax_offset_aval
=
q_aval
.
update
(
shape
=
(
1
,
attn_heads
,
1
,
1
),
dtype
=
jnp
.
float32
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
dsoftmax_offset_aval
,
wkspace_aval
@
staticmethod
@
staticmethod
def
outer_abstract
(
*
args
,
**
kwargs
):
def
outer_abstract
(
*
args
,
**
kwargs
):
"""
"""
Fused attention fwd outer primitive abstract
Fused attention fwd outer primitive abstract
"""
"""
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
_
=
FusedAttnBwdPrimitive
.
abstract
(
*
args
,
**
kwargs
)
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
dsoftmax_offset_aval
,
_
=
(
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
FusedAttnBwdPrimitive
.
abstract
(
*
args
,
**
kwargs
)
)
return
dq_aval
,
dk_aval
,
dv_aval
,
dbias_aval
,
dsoftmax_offset_aval
@
staticmethod
@
staticmethod
def
lowering
(
def
lowering
(
...
@@ -815,6 +861,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -815,6 +861,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -866,6 +913,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -866,6 +913,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -897,6 +945,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -897,6 +945,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
deterministic
=
not
FusedAttnHelper
.
is_non_deterministic_allowed
(),
window_size_left
=
window_size_left
,
window_size_left
=
window_size_left
,
window_size_right
=
window_size_right
,
window_size_right
=
window_size_right
,
softmax_type
=
int
(
config
.
softmax_type
.
value
),
)
)
@
staticmethod
@
staticmethod
...
@@ -905,6 +954,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -905,6 +954,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -993,11 +1043,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -993,11 +1043,12 @@ class FusedAttnBwdPrimitive(BasePrimitive):
q_cu_seqlen
=
generate_cu_seqlen
(
q_seqlen
.
flatten
())
q_cu_seqlen
=
generate_cu_seqlen
(
q_seqlen
.
flatten
())
kv_cu_seqlen
=
generate_cu_seqlen
(
kv_seqlen
.
flatten
())
kv_cu_seqlen
=
generate_cu_seqlen
(
kv_seqlen
.
flatten
())
dq
,
dk
,
dv
,
dbias
,
_
=
FusedAttnBwdPrimitive
.
inner_primitive
.
bind
(
dq
,
dk
,
dv
,
dbias
,
dsoftmax_offset
,
_
=
FusedAttnBwdPrimitive
.
inner_primitive
.
bind
(
q
,
q
,
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1012,15 +1063,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -1012,15 +1063,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_kv_segment_pos
,
_kv_segment_pos
,
config
=
config
,
config
=
config
,
)
)
return
dq
,
dk
,
dv
,
dbias
return
dq
,
dk
,
dv
,
dbias
,
dsoftmax_offset
@
staticmethod
@
staticmethod
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
def
batcher
(
batched_args
,
batch_dims
,
*
,
config
):
check_valid_batch_dims
(
batch_dims
)
check_valid_batch_dims
(
batch_dims
)
assert
FusedAttnBwdPrimitive
.
outer_primitive
is
not
None
assert
FusedAttnBwdPrimitive
.
outer_primitive
is
not
None
q_bdim
,
k_bdim
,
v_bdim
,
*
_
=
batch_dims
q_bdim
,
k_bdim
,
v_bdim
,
bias_bdim
,
softmax_offset_bdim
,
*
_
=
batch_dims
out_bdims
=
q_bdim
,
k_bdim
,
v_bdim
,
q
_bdim
out_bdims
=
q_bdim
,
k_bdim
,
v_bdim
,
bias_bdim
,
softmax_offset
_bdim
return
(
return
(
FusedAttnBwdPrimitive
.
outer_primitive
.
bind
(
*
batched_args
,
config
=
config
),
FusedAttnBwdPrimitive
.
outer_primitive
.
bind
(
*
batched_args
,
config
=
config
),
out_bdims
,
out_bdims
,
...
@@ -1033,11 +1084,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -1033,11 +1084,13 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
return
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
return
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
)
@
staticmethod
@
staticmethod
def
partition
(
config
,
mesh
,
arg_infos
,
result_infos
):
def
partition
(
config
,
mesh
,
arg_infos
,
result_infos
):
...
@@ -1046,21 +1099,30 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -1046,21 +1099,30 @@ class FusedAttnBwdPrimitive(BasePrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
,
)
def
sharded_impl
(
def
sharded_impl
(
q
,
q
,
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1074,36 +1136,43 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -1074,36 +1136,43 @@ class FusedAttnBwdPrimitive(BasePrimitive):
_q_segment_pos
,
_q_segment_pos
,
_kv_segment_pos
,
_kv_segment_pos
,
):
):
local_dq
,
local_dk
,
local_dv
,
local_dbias
=
FusedAttnBwdPrimitive
.
impl
(
local_dq
,
local_dk
,
local_dv
,
local_dbias
,
local_dsoftmax_offset
=
(
q
,
FusedAttnBwdPrimitive
.
impl
(
k
,
q
,
v
,
k
,
bias
,
v
,
softmax_aux
,
bias
,
rng_state
,
softmax_offset
,
output
,
softmax_aux
,
doutput
,
rng_state
,
q_cu_seqlen
,
output
,
kv_cu_seqlen
,
doutput
,
q_seq_offsets
,
q_cu_seqlen
,
k_seq_offsets
,
kv_cu_seqlen
,
_q_segment_ids
,
q_seq_offsets
,
_kv_segment_ids
,
k_seq_offsets
,
_q_segment_pos
,
_q_segment_ids
,
_kv_segment_pos
,
_kv_segment_ids
,
config
=
config
,
_q_segment_pos
,
_kv_segment_pos
,
config
=
config
,
)
)
)
global_dbias
=
local_dbias
global_dbias
=
local_dbias
if
config
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
:
if
config
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
:
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
local_dbias
,
mesh
)
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
local_dbias
,
mesh
)
return
local_dq
,
local_dk
,
local_dv
,
global_dbias
global_dsoftmax_offset
=
local_dsoftmax_offset
if
config
.
softmax_type
==
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
global_dsoftmax_offset
=
all_reduce_sum_along_dp_fsdp
(
local_dsoftmax_offset
,
mesh
)
return
local_dq
,
local_dk
,
local_dv
,
global_dbias
,
global_dsoftmax_offset
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
@
staticmethod
@
staticmethod
def
shardy_sharding_rule
(
config
,
mesh
,
value_types
,
result_types
):
def
shardy_sharding_rule
(
config
,
mesh
,
value_types
,
result_types
):
del
config
,
mesh
del
config
,
mesh
# We only care about the four first arguments.
# Keep in sync with `infer_sharding_from_operands`.
# Keep in sync with `infer_sharding_from_operands`.
input_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
value_types
)))
input_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
value_types
)))
output_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
result_types
)))
output_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
result_types
)))
...
@@ -1229,6 +1298,11 @@ class _FusedAttnCPWithAllGatherHelper:
...
@@ -1229,6 +1298,11 @@ class _FusedAttnCPWithAllGatherHelper:
if
self
.
config
.
dropout_probability
!=
0.0
:
if
self
.
config
.
dropout_probability
!=
0.0
:
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
if
self
.
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
raise
ValueError
(
f
"
{
header
}
only supports VANILLA_SOFTMAX, got:
{
self
.
config
.
softmax_type
}
"
)
def
get_adjusted_mask
(
self
):
def
get_adjusted_mask
(
self
):
"""Converts the mask for context parallelism."""
"""Converts the mask for context parallelism."""
if
self
.
config
.
attn_mask_type
==
AttnMaskType
.
CAUSAL_MASK
:
if
self
.
config
.
attn_mask_type
==
AttnMaskType
.
CAUSAL_MASK
:
...
@@ -1240,6 +1314,7 @@ class _FusedAttnCPWithAllGatherHelper:
...
@@ -1240,6 +1314,7 @@ class _FusedAttnCPWithAllGatherHelper:
return
_FusedAttnConfig
(
return
_FusedAttnConfig
(
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_mask_type
=
self
.
get_adjusted_mask
(),
attn_mask_type
=
self
.
get_adjusted_mask
(),
softmax_type
=
self
.
config
.
softmax_type
,
qkv_layout
=
self
.
config
.
qkv_layout
,
qkv_layout
=
self
.
config
.
qkv_layout
,
scaling_factor
=
self
.
config
.
scaling_factor
,
scaling_factor
=
self
.
config
.
scaling_factor
,
dropout_probability
=
self
.
config
.
dropout_probability
,
dropout_probability
=
self
.
config
.
dropout_probability
,
...
@@ -1376,7 +1451,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1376,7 +1451,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
out_sharding
,
softmax_aux_sharding
,
rng_state_sharding
)
out_shardings
=
(
out_sharding
,
softmax_aux_sharding
,
rng_state_sharding
)
...
@@ -1385,6 +1460,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1385,6 +1460,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -1404,7 +1480,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1404,7 +1480,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
# meeting the expectation of the SPMD model.
# meeting the expectation of the SPMD model.
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# mask/sequence length tensor to avoid this unrolled loop.
# mask/sequence length tensor to avoid this unrolled loop.
def
_cross_attn
(
idx
,
q
,
k
,
v
,
bias
,
q_seqlen
,
kv_seqlen
,
seed
):
def
_cross_attn
(
idx
,
q
,
k
,
v
,
bias
,
softmax_offset
,
q_seqlen
,
kv_seqlen
,
seed
):
kv_max_seqlen
=
k
.
shape
[
1
]
kv_max_seqlen
=
k
.
shape
[
1
]
kv_seqlen_per_subrank
=
kv_max_seqlen
//
(
cp_size
*
2
)
kv_seqlen_per_subrank
=
kv_max_seqlen
//
(
cp_size
*
2
)
assert
kv_max_seqlen
%
cp_size
==
0
,
"sequence length must evenly divide cp size"
assert
kv_max_seqlen
%
cp_size
==
0
,
"sequence length must evenly divide cp size"
...
@@ -1431,6 +1507,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1431,6 +1507,7 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_unmasked
,
k_unmasked
,
v_unmasked
,
v_unmasked
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
q_seqlen_for_step
,
q_seqlen_for_step
,
kv_seqlen_for_step
,
kv_seqlen_for_step
,
...
@@ -1453,7 +1530,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1453,7 +1530,9 @@ class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
k_ag
,
v_ag
=
helper
.
all_gather_kv
(
k
,
v
)
k_ag
,
v_ag
=
helper
.
all_gather_kv
(
k
,
v
)
functions
=
[
functions
=
[
partial
(
_cross_attn
,
idx
,
q
,
k_ag
,
v_ag
,
bias
,
q_seqlen
,
kv_seqlen
,
seed
)
partial
(
_cross_attn
,
idx
,
q
,
k_ag
,
v_ag
,
bias
,
softmax_offset
,
q_seqlen
,
kv_seqlen
,
seed
)
for
idx
in
range
(
cp_size
)
for
idx
in
range
(
cp_size
)
]
]
...
@@ -1492,18 +1571,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1492,18 +1571,27 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
,
)
def
impl
(
def
impl
(
q
,
q
,
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1527,6 +1615,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1527,6 +1615,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1562,11 +1651,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1562,11 +1651,12 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
num_kv_chunks
=
kv_max_seqlen
//
kv_seqlens_for_rank
[
sub_idx
]
num_kv_chunks
=
kv_max_seqlen
//
kv_seqlens_for_rank
[
sub_idx
]
kv_seqlen_for_step
=
(
kv_seqlen
//
(
cp_size
*
2
))
*
num_kv_chunks
kv_seqlen_for_step
=
(
kv_seqlen
//
(
cp_size
*
2
))
*
num_kv_chunks
dq_local
,
dk_local
,
dv_local
,
dbias_local
=
FusedAttnBwdPrimitive
.
impl
(
dq_local
,
dk_local
,
dv_local
,
dbias_local
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q_split
[
sub_idx
],
q_split
[
sub_idx
],
k_unmasked
,
k_unmasked
,
v_unmasked
,
v_unmasked
,
bias
,
bias
,
softmax_offset
,
softmax_aux_split
[
sub_idx
],
softmax_aux_split
[
sub_idx
],
rng_state
,
rng_state
,
output_split
[
sub_idx
],
output_split
[
sub_idx
],
...
@@ -1604,6 +1694,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1604,6 +1694,7 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
k_ag
,
k_ag
,
v_ag
,
v_ag
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1621,7 +1712,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1621,7 +1712,9 @@ class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
dq
,
dk_local
,
dv_local
,
dbias
=
lax
.
switch
(
cp_rank
,
functions
)
dq
,
dk_local
,
dv_local
,
dbias
=
lax
.
switch
(
cp_rank
,
functions
)
dk
,
dv
=
helper
.
reduce_scatter_dkv
(
dk_local
,
dv_local
)
dk
,
dv
=
helper
.
reduce_scatter_dkv
(
dk_local
,
dv_local
)
return
dq
,
dk
,
dv
,
dbias
# Return dummy dsoftmax_offset for arity matching (all-gather CP doesn't use it)
dummy_dsoftmax_offset
=
jnp
.
empty_like
(
softmax_offset
)
return
dq
,
dk
,
dv
,
dbias
,
dummy_dsoftmax_offset
return
mesh
,
impl
,
out_shardings
,
arg_shardings
return
mesh
,
impl
,
out_shardings
,
arg_shardings
...
@@ -1679,6 +1772,11 @@ class _FusedAttnCPWithP2PHelper:
...
@@ -1679,6 +1772,11 @@ class _FusedAttnCPWithP2PHelper:
if
self
.
config
.
dropout_probability
!=
0.0
:
if
self
.
config
.
dropout_probability
!=
0.0
:
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
raise
ValueError
(
f
"
{
header
}
does not support dropout"
)
if
self
.
config
.
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
raise
ValueError
(
f
"
{
header
}
only supports VANILLA_SOFTMAX, got:
{
self
.
config
.
softmax_type
}
"
)
# We want to encourage use of scan loop to minimize unrolling and ensure more
# We want to encourage use of scan loop to minimize unrolling and ensure more
# predictable scheduling from XLA. The unrolled flavor will be supported but
# predictable scheduling from XLA. The unrolled flavor will be supported but
# not the prefered implementation.
# not the prefered implementation.
...
@@ -1703,6 +1801,7 @@ class _FusedAttnCPWithP2PHelper:
...
@@ -1703,6 +1801,7 @@ class _FusedAttnCPWithP2PHelper:
return
_FusedAttnConfig
(
return
_FusedAttnConfig
(
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_bias_type
=
self
.
config
.
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
self
.
config
.
softmax_type
,
qkv_layout
=
QKVLayout
.
BSHD_BS2HD
,
qkv_layout
=
QKVLayout
.
BSHD_BS2HD
,
scaling_factor
=
self
.
config
.
scaling_factor
,
scaling_factor
=
self
.
config
.
scaling_factor
,
dropout_probability
=
self
.
config
.
dropout_probability
,
dropout_probability
=
self
.
config
.
dropout_probability
,
...
@@ -1783,7 +1882,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1783,7 +1882,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
# Ensure segment_pos gets same sharding as ID.
# Ensure segment_pos gets same sharding as ID.
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
...
@@ -1795,6 +1894,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1795,6 +1894,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -1840,6 +1940,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1840,6 +1940,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen_per_step
,
q_seqlen_per_step
,
kv_seqlen_per_step
,
kv_seqlen_per_step
,
...
@@ -1865,6 +1966,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1865,6 +1966,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv_part
,
kv_part
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen_per_step
,
q_seqlen_per_step
,
kv_seqlen_per_step
,
kv_seqlen_per_step
,
...
@@ -1887,6 +1989,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -1887,6 +1989,7 @@ class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen_per_step
,
q_seqlen_per_step
,
kv_seqlen_per_step
,
kv_seqlen_per_step
,
...
@@ -1990,18 +2093,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -1990,18 +2093,24 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
k_spec
=
get_padded_spec
(
arg_infos
[
1
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
v_spec
=
get_padded_spec
(
arg_infos
[
2
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
bias_spec
=
get_padded_spec
(
arg_infos
[
3
])
softmax_offset_spec
=
get_padded_spec
(
arg_infos
[
4
])
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dq_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
q_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dk_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
k_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
v_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
dbias_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
bias_spec
))
# Ring attention doesn't use dsoftmax_offset, but we need to return it for arity matching
dsoftmax_offset_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
softmax_offset_spec
))
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
# Ensure segment_pos gets same sharding as ID.
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
out_shardings
=
(
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
)
dq_sharding
,
dk_sharding
,
dv_sharding
,
dbias_sharding
,
dsoftmax_offset_sharding
,
)
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
.
check_supported
()
helper
.
check_supported
()
...
@@ -2011,6 +2120,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2011,6 +2120,7 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2054,11 +2164,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2054,11 +2164,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
def
mask_compute
(
attn_mask_type
):
def
mask_compute
(
attn_mask_type
):
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q
,
q
,
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2082,11 +2193,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2082,11 +2193,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
q_seqlen_per_step
=
helper
.
adjust_seqlen
(
q_seqlen
,
q_max_seqlen
,
idx
)
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
//
2
kv_seqlen_per_step
=
helper
.
adjust_seqlen
(
kv_seqlen
,
kv_max_seqlen
,
idx
)
//
2
kv_part
=
lax
.
slice_in_dim
(
kv
,
0
,
kv_max_seqlen
//
2
,
axis
=
1
)
kv_part
=
lax
.
slice_in_dim
(
kv
,
0
,
kv_max_seqlen
//
2
,
axis
=
1
)
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q
,
q
,
kv_part
,
kv_part
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2120,11 +2232,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2120,11 +2232,12 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
softmax_aux
,
q_max_seqlen
//
2
,
q_max_seqlen
,
axis
=
2
softmax_aux
,
q_max_seqlen
//
2
,
q_max_seqlen
,
axis
=
2
)
)
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dk_dv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q_part
,
q_part
,
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
softmax_aux_part
,
softmax_aux_part
,
rng_state
,
rng_state
,
output_part
,
output_part
,
...
@@ -2184,7 +2297,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2184,7 +2297,9 @@ class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
dk
,
dv
=
helper
.
unstack_kv
(
dk_dv
)
dk
,
dv
=
helper
.
unstack_kv
(
dk_dv
)
return
dq
,
dk
,
dv
,
global_dbias
# Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset
=
jnp
.
empty_like
(
_softmax_offset
)
return
dq
,
dk
,
dv
,
global_dbias
,
dummy_dsoftmax_offset
return
mesh
,
ring_attn_bwd_impl
,
out_shardings
,
arg_shardings
return
mesh
,
ring_attn_bwd_impl
,
out_shardings
,
arg_shardings
...
@@ -2273,7 +2388,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -2273,7 +2388,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
mesh
,
PartitionSpec
(
get_all_mesh_axes
(),
None
)
)
)
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
=
[
arg_i
.
sharding
for
arg_i
in
arg_infos
]
arg_shardings
[
4
]
=
seed_sharding
arg_shardings
[
5
]
=
seed_sharding
# Ensure segment_pos gets same sharding as ID.
# Ensure segment_pos gets same sharding as ID.
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
...
@@ -2285,6 +2400,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -2285,6 +2400,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -2336,6 +2452,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -2336,6 +2452,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
seed
,
seed
,
q_seqlen
,
q_seqlen
,
kv_seqlen
,
kv_seqlen
,
...
@@ -2345,7 +2462,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
...
@@ -2345,7 +2462,7 @@ class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
kv_segment_ids
,
kv_segment_ids
,
q_segment_pos
,
q_segment_pos
,
kv_segment_pos
,
kv_segment_pos
,
config
,
config
=
config
,
)
)
if
config
.
window_size
!=
(
-
1
,
-
1
):
if
config
.
window_size
!=
(
-
1
,
-
1
):
...
@@ -2420,8 +2537,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2420,8 +2537,8 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
1
]
=
arg_shardings
[
-
3
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
[
-
2
]
=
arg_shardings
[
-
4
]
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding
# dq, dk, dv, dbias
, dsoftmax_offset
sharding = q, k, v, bias
, softmax_offset
sharding
out_shardings
=
tuple
(
arg
.
sharding
for
arg
in
arg_infos
[:
4
])
out_shardings
=
tuple
(
arg
.
sharding
for
arg
in
arg_infos
[:
5
])
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
=
_FusedAttnCPWithP2PHelper
(
mesh
,
config
)
helper
.
check_supported
()
helper
.
check_supported
()
...
@@ -2431,6 +2548,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2431,6 +2548,7 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
k
,
k
,
v
,
v
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2478,11 +2596,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2478,11 +2596,12 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
kv_segment_pos_next
=
helper
.
permute_kv
(
kv_segment_pos
,
cp_perm
)
kv_segment_pos_next
=
helper
.
permute_kv
(
kv_segment_pos
,
cp_perm
)
def
compute
(
config
):
def
compute
(
config
):
dq_per_step
,
dkv_per_step
,
_
,
dbias_per_step
=
FusedAttnBwdPrimitive
.
impl
(
dq_per_step
,
dkv_per_step
,
_
,
dbias_per_step
,
_
=
FusedAttnBwdPrimitive
.
impl
(
q
,
q
,
kv
,
kv
,
_not_used
,
_not_used
,
bias
,
bias
,
_softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2536,7 +2655,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
...
@@ -2536,7 +2655,9 @@ class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
global_dbias
=
all_reduce_sum_along_dp_fsdp
(
dbias
,
mesh
)
dk
,
dv
=
helper
.
unstack_kv
(
dkv
)
dk
,
dv
=
helper
.
unstack_kv
(
dkv
)
return
dq
,
dk
,
dv
,
global_dbias
# Return dummy dsoftmax_offset for arity matching (ring attention doesn't use it)
dummy_dsoftmax_offset
=
jnp
.
empty_like
(
_softmax_offset
)
return
dq
,
dk
,
dv
,
global_dbias
,
dummy_dsoftmax_offset
return
mesh
,
bwd_impl
,
out_shardings
,
arg_shardings
return
mesh
,
bwd_impl
,
out_shardings
,
arg_shardings
...
@@ -2557,10 +2678,12 @@ def _maybe_context_parallel_axis(cp_axis: str):
...
@@ -2557,10 +2678,12 @@ def _maybe_context_parallel_axis(cp_axis: str):
def
fused_attn_fwd
(
def
fused_attn_fwd
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
sequence_descriptor
:
SequenceDescriptor
,
sequence_descriptor
:
SequenceDescriptor
,
seed
:
Optional
[
jnp
.
ndarray
],
seed
:
Optional
[
jnp
.
ndarray
],
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
softmax_type
:
AttnSoftmaxType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
...
@@ -2585,6 +2708,7 @@ def fused_attn_fwd(
...
@@ -2585,6 +2708,7 @@ def fused_attn_fwd(
query has a different shape (e.g., cross-attention).
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
q_seqlen (jnp.ndarray): Sequence lengths for the query, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
kv_seqlen (jnp.ndarray): Sequence lengths for the key and value, with shape [batch,].
q_seq_offsets (jnp.ndarray):
q_seq_offsets (jnp.ndarray):
...
@@ -2594,6 +2718,7 @@ def fused_attn_fwd(
...
@@ -2594,6 +2718,7 @@ def fused_attn_fwd(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -2633,10 +2758,36 @@ def fused_attn_fwd(
...
@@ -2633,10 +2758,36 @@ def fused_attn_fwd(
assert
bias
is
None
assert
bias
is
None
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
if
softmax_offset
is
None
:
assert
(
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
),
f
"Softmax type
{
softmax_type
}
is not supported when softmax_offset is None"
if
softmax_type
==
AttnSoftmaxType
.
OFF_BY_ONE_SOFTMAX
:
num_heads
=
qkv
[
0
].
shape
[
-
2
]
# Create tensor [1, h, 1, 1] filled with zeros (logit value = 0)
# This adds exp(0 - x_max) = exp(-x_max) to the denominator,
# which contributes exactly 1 after normalization, giving: exp(x_i) / (sum(exp(x_j)) + 1)
softmax_offset
=
jnp
.
zeros
((
1
,
num_heads
,
1
,
1
),
dtype
=
jnp
.
float32
)
# Shard by heads dimension
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
else
:
assert
softmax_type
==
AttnSoftmaxType
.
VANILLA_SOFTMAX
softmax_offset
=
jnp
.
zeros
(
0
,
dtype
=
jnp
.
float32
)
else
:
assert
softmax_offset
.
dtype
==
jnp
.
float32
# Shard by heads dimension if not VANILLA_SOFTMAX
if
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
fused_config
=
_FusedAttnConfig
(
fused_config
=
_FusedAttnConfig
(
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
...
@@ -2662,6 +2813,7 @@ def fused_attn_fwd(
...
@@ -2662,6 +2813,7 @@ def fused_attn_fwd(
output
,
softmax_aux
,
rng_state
=
primitive
.
bind
(
output
,
softmax_aux
,
rng_state
=
primitive
.
bind
(
*
qkv_for_primitive
,
*
qkv_for_primitive
,
bias
,
bias
,
softmax_offset
,
seed
,
seed
,
*
seq_desc_flatten
,
*
seq_desc_flatten
,
config
=
fused_config
,
config
=
fused_config
,
...
@@ -2673,6 +2825,7 @@ def fused_attn_fwd(
...
@@ -2673,6 +2825,7 @@ def fused_attn_fwd(
def
fused_attn_bwd
(
def
fused_attn_bwd
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
softmax_aux
:
jnp
.
ndarray
,
softmax_aux
:
jnp
.
ndarray
,
rng_state
:
jnp
.
ndarray
,
rng_state
:
jnp
.
ndarray
,
output
:
jnp
.
ndarray
,
output
:
jnp
.
ndarray
,
...
@@ -2681,6 +2834,7 @@ def fused_attn_bwd(
...
@@ -2681,6 +2834,7 @@ def fused_attn_bwd(
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -2702,6 +2856,7 @@ def fused_attn_bwd(
...
@@ -2702,6 +2856,7 @@ def fused_attn_bwd(
query has a different shape (e.g., cross-attention).
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_offset (Optional[jnp.ndarray]): An optional softmax offset tensor.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
output (jnp.ndarray): The output tensor from the forward pass.
output (jnp.ndarray): The output tensor from the forward pass.
...
@@ -2714,6 +2869,7 @@ def fused_attn_bwd(
...
@@ -2714,6 +2869,7 @@ def fused_attn_bwd(
The offsets in the sequence dim for the query, with shape [batch + 1,].
The offsets in the sequence dim for the query, with shape [batch + 1,].
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -2755,6 +2911,28 @@ def fused_attn_bwd(
...
@@ -2755,6 +2911,28 @@ def fused_attn_bwd(
assert
bias
is
None
assert
bias
is
None
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
bias
=
jnp
.
zeros
(
0
,
dtype
=
qkv
[
0
].
dtype
)
if
softmax_offset
is
None
:
assert
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
,
f
"Unknown
{
softmax_type
=
}
"
if
softmax_type
==
AttnSoftmaxType
.
OFF_BY_ONE_SOFTMAX
:
num_heads
=
qkv
[
0
].
shape
[
-
2
]
# Create tensor [1, h, 1, 1] filled with zeros
softmax_offset
=
jnp
.
zeros
((
1
,
num_heads
,
1
,
1
),
dtype
=
jnp
.
float32
)
# Shard by heads dimension
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
elif
softmax_type
==
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_offset
=
jnp
.
zeros
(
0
,
dtype
=
jnp
.
float32
)
else
:
raise
NotImplementedError
(
f
"Unknown
{
softmax_type
=
}
"
)
else
:
softmax_offset
=
softmax_offset
.
astype
(
jnp
.
float32
)
# Shard by heads dimension if not VANILLA_SOFTMAX
if
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
softmax_offset
=
with_sharding_constraint_by_logical_axes
(
softmax_offset
,
(
None
,
HEAD_AXES
,
None
,
None
)
)
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
# sm100+
# sm100+
compute_capabilities
=
get_all_device_compute_capability
()
compute_capabilities
=
get_all_device_compute_capability
()
...
@@ -2767,6 +2945,7 @@ def fused_attn_bwd(
...
@@ -2767,6 +2945,7 @@ def fused_attn_bwd(
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
...
@@ -2788,9 +2967,10 @@ def fused_attn_bwd(
...
@@ -2788,9 +2967,10 @@ def fused_attn_bwd(
primitive
=
FusedRingAttnBwdPrimitive
.
outer_primitive
primitive
=
FusedRingAttnBwdPrimitive
.
outer_primitive
seq_desc_flatten
,
_
=
jax
.
tree
.
flatten
(
sequence_descriptor
)
seq_desc_flatten
,
_
=
jax
.
tree
.
flatten
(
sequence_descriptor
)
*
qkv_grads
,
bias_grad
=
primitive
.
bind
(
*
qkv_grads
,
bias_grad
,
softmax_offset_grad
=
primitive
.
bind
(
*
qkv_for_primitive
,
*
qkv_for_primitive
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -2798,4 +2978,4 @@ def fused_attn_bwd(
...
@@ -2798,4 +2978,4 @@ def fused_attn_bwd(
*
seq_desc_flatten
,
*
seq_desc_flatten
,
config
=
fused_config
,
config
=
fused_config
,
)
)
return
tuple
(
qkv_grads
[:
len
(
qkv
)]),
bias_grad
return
tuple
(
qkv_grads
[:
len
(
qkv
)]),
bias_grad
,
softmax_offset_grad
transformer_engine/jax/cpp_extensions/gemm.py
View file @
970620a5
...
@@ -39,12 +39,12 @@ from ..quantize import (
...
@@ -39,12 +39,12 @@ from ..quantize import (
Quantizer
,
Quantizer
,
GroupedQuantizer
,
GroupedQuantizer
,
QuantizerSet
,
QuantizerSet
,
QuantizeLayout
,
noop_quantizer_set
,
noop_quantizer_set
,
is_fp8_gemm_with_all_layouts_supported
,
is_fp8_gemm_with_all_layouts_supported
,
apply_padding_to_scale_inv
,
apply_padding_to_scale_inv
,
get_quantize_config_with_recipe
,
get_quantize_config_with_recipe
,
get_global_quantize_recipe
,
get_global_quantize_recipe
,
QuantizeLayout
,
)
)
from
.misc
import
get_padded_spec
,
is_all_reduce_in_float32
from
.misc
import
get_padded_spec
,
is_all_reduce_in_float32
from
..sharding
import
(
from
..sharding
import
(
...
...
transformer_engine/jax/cpp_extensions/misc.py
View file @
970620a5
...
@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
...
@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
transpose. Note, transpose_axis should be greater than static_axis_boundary
transpose. Note, transpose_axis should be greater than static_axis_boundary
examples:
examples:
X
in
shape (dim0, dim1, dim2, dim3, dim4)
X
of
shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis == 2
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
Xt = (dim2, dim3, dim4, dim0, dim1)
...
...
transformer_engine/jax/cpp_extensions/normalization.py
View file @
970620a5
...
@@ -35,9 +35,9 @@ from ..sharding import (
...
@@ -35,9 +35,9 @@ from ..sharding import (
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
(
from
..quantize
import
(
Quantizer
,
Quantizer
,
QuantizeLayout
,
DelayedScaleQuantizer
,
DelayedScaleQuantizer
,
ScalingMode
,
ScalingMode
,
QuantizeLayout
,
)
)
...
...
transformer_engine/jax/cpp_extensions/quantization.py
View file @
970620a5
...
@@ -40,11 +40,11 @@ from ..quantize import (
...
@@ -40,11 +40,11 @@ from ..quantize import (
GroupedScaledTensor1x
,
GroupedScaledTensor1x
,
Quantizer
,
Quantizer
,
GroupedQuantizer
,
GroupedQuantizer
,
QuantizeLayout
,
ScalingMode
,
ScalingMode
,
compute_scale_from_amax
,
compute_scale_from_amax
,
NoScaleTensor
,
NoScaleTensor
,
get_rht_matrix
,
get_rht_matrix
,
QuantizeLayout
,
)
)
...
@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
...
@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
amax_spec
=
get_padded_spec
(
arg_infos
[
2
])
amax_spec
=
get_padded_spec
(
arg_infos
[
2
])
sr_rng_state_spec
=
get_padded_spec
(
arg_infos
[
3
])
out_sharding
=
NamedSharding
(
out_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
*
x_spec
),
PartitionSpec
(
*
x_spec
),
...
@@ -551,11 +552,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
...
@@ -551,11 +552,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
)
arg_shardings
=
list
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
=
list
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
[
3
]
=
NamedSharding
(
if
len
(
sr_rng_state_spec
)
>
1
:
mesh
,
# sr_rng_state shape [n_devices, state_per_device]
PartitionSpec
(
tuple
(
x
for
x
in
x_spec
if
x
is
not
None
),
None
),
sr_rng_state_spec
=
(
*
tuple
(
x
for
x
in
x_spec
if
x
is
not
None
),
None
)
desc
=
"BaseDBiasQuantizePrimitive.sr_rng_state"
,
arg_shardings
[
3
]
=
NamedSharding
(
)
mesh
,
PartitionSpec
(
*
sr_rng_state_spec
),
desc
=
"BaseDBiasQuantizePrimitive.sr_rng_state"
,
)
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
out_shardings
=
(
out_sharding
,
out_sharding
,
...
@@ -654,10 +658,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
...
@@ -654,10 +658,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias
=
input_spec
[
flatten_axis
:]
if
is_dbias
else
(
prefix
+
"_dbias"
,)
dbias
=
input_spec
[
flatten_axis
:]
if
is_dbias
else
(
prefix
+
"_dbias"
,)
amax
=
(
BATCHING
+
prefix
+
"_amax"
,)
amax
=
(
BATCHING
+
prefix
+
"_amax"
,)
scale
=
(
BATCHING
+
prefix
+
"_scale"
,)
scale
=
(
BATCHING
+
prefix
+
"_scale"
,)
sr_rng_state
=
(
sr_rng_state
=
(
BATCHING
+
prefix
+
"_sr_rng_state"
,)
BATCHING
+
prefix
+
"_sr_rng_state_partition_axis"
,
if
value_types
[
3
].
shape
!=
[
0
]:
BATCHING
+
prefix
+
"sr_rng_state_data_axis"
,
sr_rng_state
=
(
)
BATCHING
+
prefix
+
"_sr_rng_state_devices"
,
prefix
+
"sr_rng_state_data"
,
)
post_rht_amax
=
(
BATCHING
+
prefix
+
"_post_rht_amax"
,)
post_rht_amax
=
(
BATCHING
+
prefix
+
"_post_rht_amax"
,)
rht_matrix
=
(
BATCHING
+
prefix
+
"_rht_matrix_1"
,
BATCHING
+
prefix
+
"_rht_matrix_2"
)
rht_matrix
=
(
BATCHING
+
prefix
+
"_rht_matrix_1"
,
BATCHING
+
prefix
+
"_rht_matrix_2"
)
...
@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
...
@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
if
force_1x_quantization
:
if
force_1x_quantization
:
q_layout
=
QuantizeLayout
.
ROWWISE
q_layout
=
QuantizeLayout
.
ROWWISE
sr_rng_state
=
None
sr_rng_state
=
jnp
.
empty
((
0
,),
jnp
.
uint32
)
if
quantizer
.
scaling_mode
.
is_nvfp4_scaling
:
if
quantizer
.
scaling_mode
.
is_nvfp4_scaling
:
# Only NVFP4 scaling modes support stochastic rounding
# Only NVFP4 scaling modes support stochastic rounding
if
quantizer
.
stochastic_rounding_rng_state
is
not
None
:
if
quantizer
.
stochastic_rounding_rng_state
is
not
None
:
...
@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
...
@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
x
.
data
,
x
.
data
,
scale
,
scale
,
amax
,
amax
,
(
sr_rng_state
,
sr_rng_state
if
sr_rng_state
is
not
None
else
jnp
.
empty
((
get_num_devices_in_mesh
(),
1
),
jnp
.
uint32
)
),
post_rht_amax
if
post_rht_amax
is
not
None
else
jnp
.
zeros
((
1
,),
jnp
.
float32
),
post_rht_amax
if
post_rht_amax
is
not
None
else
jnp
.
zeros
((
1
,),
jnp
.
float32
),
rht_matrix
,
rht_matrix
,
out_dtype
=
quantizer
.
q_dtype
,
out_dtype
=
quantizer
.
q_dtype
,
...
@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
...
@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_dtype
=
quantizer
.
get_scale_dtype
(),
is_dbias
=
is_dbias
if
not
quantizer
.
scaling_mode
.
is_nvfp4_scaling
else
False
,
is_dbias
=
is_dbias
if
not
quantizer
.
scaling_mode
.
is_nvfp4_scaling
else
False
,
is_outer
=
True
,
is_outer
=
True
,
stochastic_rounding
=
sr_rng_state
is
not
None
,
stochastic_rounding
=
sr_rng_state
.
size
!=
0
,
use_rht
=
use_rht
,
use_rht
=
use_rht
,
)
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
...
...
transformer_engine/jax/cpp_extensions/softmax.py
View file @
970620a5
...
@@ -11,10 +11,11 @@ import jax
...
@@ -11,10 +11,11 @@ import jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax
import
dtypes
,
ffi
from
jax
import
dtypes
,
ffi
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
.attention
import
AttnSoftmaxType
from
.base
import
BasePrimitive
,
register_primitive
from
.base
import
BasePrimitive
,
register_primitive
from
.misc
import
get_padded_spec
,
check_valid_batch_dims
from
.misc
import
get_padded_spec
,
check_valid_batch_dims
from
..softmax
import
SoftmaxType
from
..softmax
import
Softmax
Fusion
Type
__all__
=
[
__all__
=
[
...
@@ -32,7 +33,8 @@ __all__ = [
...
@@ -32,7 +33,8 @@ __all__ = [
def
is_softmax_kernel_available
(
def
is_softmax_kernel_available
(
softmax_type
:
SoftmaxType
,
softmax_fusion_type
:
SoftmaxFusionType
,
softmax_type
:
AttnSoftmaxType
,
batch
:
int
,
batch
:
int
,
heads
:
int
,
heads
:
int
,
q_seqlen
:
int
,
q_seqlen
:
int
,
...
@@ -40,15 +42,18 @@ def is_softmax_kernel_available(
...
@@ -40,15 +42,18 @@ def is_softmax_kernel_available(
dtype
:
jnp
.
dtype
,
dtype
:
jnp
.
dtype
,
):
):
"""check softmax available"""
"""check softmax available"""
if
softmax_type
is
SoftmaxType
.
SCALED
:
if
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
return
False
if
softmax_fusion_type
is
SoftmaxFusionType
.
SCALED
:
return
ScaledSoftmaxFwdPrimitive
.
is_kernel_available
(
return
ScaledSoftmaxFwdPrimitive
.
is_kernel_available
(
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
)
)
if
softmax_type
is
SoftmaxType
.
SCALED_MASKED
:
if
softmax_
fusion_
type
is
Softmax
Fusion
Type
.
SCALED_MASKED
:
return
ScaledMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
return
ScaledMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
)
)
if
softmax_type
is
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
:
if
softmax_
fusion_
type
is
Softmax
Fusion
Type
.
SCALED_UPPER_TRIANG_MASKED
:
return
ScaledUpperTriangMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
return
ScaledUpperTriangMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
)
)
...
@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
...
@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive
(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
)
register_primitive
(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
)
def
jax_scaled_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
,
softmax_offset
:
jnp
.
ndarray
|
float
|
None
=
None
):
"""
"""
JAX based implementation of scaled softmax
JAX based implementation of scaled softmax
"""
"""
if
softmax_offset
is
not
None
:
return
jax_general_softmax
(
scale_factor
*
logits
,
offset
=
softmax_offset
)
return
jax
.
nn
.
softmax
(
scale_factor
*
logits
)
return
jax
.
nn
.
softmax
(
scale_factor
*
logits
)
def
jax_scaled_masked_softmax
(
logits
:
jnp
.
ndarray
,
mask
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_masked_softmax
(
logits
:
jnp
.
ndarray
,
mask
:
jnp
.
ndarray
,
scale_factor
:
float
,
softmax_offset
:
jnp
.
ndarray
|
float
|
None
=
None
,
):
"""
"""
JAX based implementation of scaled and masked softmax
JAX based implementation of scaled and masked softmax
"""
"""
if
softmax_offset
is
not
None
:
return
jax_general_softmax
(
logits
*
scale_factor
,
offset
=
softmax_offset
,
where
=
mask
!=
1
)
return
jax
.
nn
.
softmax
(
logits
*
scale_factor
,
where
=
mask
!=
1
)
return
jax
.
nn
.
softmax
(
logits
*
scale_factor
,
where
=
mask
!=
1
)
def
jax_scaled_upper_triang_masked_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_upper_triang_masked_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
,
softmax_offset
:
jnp
.
ndarray
|
float
|
None
=
None
):
"""
"""
JAX based implementation of scaled and upper triangle masked softmax
JAX based implementation of scaled and upper triangle masked softmax
"""
"""
mask
=
1
-
jnp
.
tril
(
jnp
.
ones_like
(
logits
))
mask
=
1
-
jnp
.
tril
(
jnp
.
ones_like
(
logits
))
return
jax_scaled_masked_softmax
(
logits
,
mask
,
scale_factor
)
return
jax_scaled_masked_softmax
(
logits
,
mask
,
scale_factor
,
softmax_offset
)
def
jax_general_softmax
(
x
:
jnp
.
ndarray
,
axis
:
int
=
-
1
,
where
:
jnp
.
ndarray
|
None
=
None
,
initial
:
jnp
.
ndarray
=
-
jnp
.
inf
,
offset
:
jnp
.
ndarray
|
float
|
None
=
None
,
)
->
jnp
.
ndarray
:
"""
JAX based implementation of general softmax with optional masking and offset.
"""
# Compute max of x
x_max
=
jnp
.
max
(
x
,
axis
,
where
=
where
,
initial
=
initial
,
keepdims
=
True
)
if
offset
is
not
None
:
# Cast offset to x.dtype to prevent type promotion
if
isinstance
(
offset
,
(
int
,
float
)):
offset
=
jnp
.
array
(
offset
,
dtype
=
x
.
dtype
)
else
:
offset
=
offset
.
astype
(
x
.
dtype
)
# Include offset in max: x_max = max(x_max, offset)
# This is equivalent to computing max over [x..., offset]
x_max
=
jnp
.
maximum
(
x_max
,
offset
)
x_safe
=
x
if
where
is
None
else
jnp
.
where
(
where
,
x
,
initial
)
unnormalized
=
jnp
.
exp
(
x_safe
-
x_max
)
denominator
=
jnp
.
sum
(
unnormalized
,
axis
,
where
=
where
,
keepdims
=
True
)
if
offset
is
not
None
:
# Add exp(offset - x_max) to denominator
denominator
=
denominator
+
jnp
.
exp
(
offset
-
x_max
)
result
=
unnormalized
/
denominator
if
where
is
not
None
:
result
=
jnp
.
where
(
where
,
result
,
0
)
return
result
def
scaled_softmax_fwd
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
)
->
jnp
.
ndarray
:
def
scaled_softmax_fwd
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
)
->
jnp
.
ndarray
:
...
...
transformer_engine/jax/csrc/extensions.h
View file @
970620a5
...
@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
...
@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
FusedAttnBackwardHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
FusedAttnBackwardHandler
);
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
mask_type
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
size_t
q_num_heads
,
size_t
kv_num_heads
,
float
dropout_probability
,
size_t
q_attn_heads
,
size_t
kv_attn_heads
,
size_t
q_max_seqlen
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_left
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_right
);
int64_t
window_size_left
,
int64_t
window_size_right
);
pybind11
::
tuple
GetFusedAttnForwardWorkspaceSizes
(
pybind11
::
tuple
GetFusedAttnForwardWorkspaceSizes
(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
);
DType
dtype
,
bool
is_training
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
);
pybind11
::
tuple
GetFusedAttnBackwardWorkspaceSizes
(
pybind11
::
tuple
GetFusedAttnBackwardWorkspaceSizes
(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_
QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
NVTE_Mask_Type
mask_type
,
NVTE_
Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
DType
dtype
,
bool
is_training
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_right
);
int64_t
window_size_left
,
int64_t
window_size_right
);
// GEMM
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
GemmHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
GemmHandler
);
...
...
transformer_engine/jax/csrc/extensions/attention.cpp
View file @
970620a5
...
@@ -11,14 +11,12 @@
...
@@ -11,14 +11,12 @@
namespace
transformer_engine
{
namespace
transformer_engine
{
namespace
jax
{
namespace
jax
{
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
mask_type
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
size_t
q_attn_heads
,
size_t
kv_attn_heads
,
float
dropout_probability
,
size_t
q_attn_heads
,
size_t
kv_attn_heads
,
size_t
q_max_seqlen
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_left
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_right
)
{
int64_t
window_size_left
,
int64_t
window_size_right
)
{
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
auto
backend
=
nvte_get_fused_attn_backend
(
auto
backend
=
nvte_get_fused_attn_backend
(
is_training
,
static_cast
<
NVTEDType
>
(
q_dtype
),
static_cast
<
NVTEDType
>
(
kv_dtype
),
qkv_layout
,
is_training
,
static_cast
<
NVTEDType
>
(
q_dtype
),
static_cast
<
NVTEDType
>
(
kv_dtype
),
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
dropout_probability
,
q_attn_heads
,
kv_attn_heads
,
bias_type
,
mask_type
,
softmax_type
,
dropout_probability
,
q_attn_heads
,
kv_attn_heads
,
...
@@ -39,7 +37,8 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
...
@@ -39,7 +37,8 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
const
size_t
kv_max_seqlen
,
DType
dtype
,
const
size_t
kv_max_seqlen
,
DType
dtype
,
NVTE_Bias_Type
bias_type
,
NVTE_Fused_Attn_Backend
backend
,
NVTE_Bias_Type
bias_type
,
NVTE_Fused_Attn_Backend
backend
,
void
*
softmax_buf
,
void
*
rng_state_buf
=
nullptr
,
void
*
softmax_buf
,
void
*
rng_state_buf
=
nullptr
,
void
*
bias_buf
=
nullptr
)
{
void
*
bias_buf
=
nullptr
,
void
*
softmax_offset_buf
=
nullptr
)
{
// all backends need softmax but expect different shapes/dtypes
// all backends need softmax but expect different shapes/dtypes
// start with the max512 sequence length softmax shape/dtype and correct later
// start with the max512 sequence length softmax shape/dtype and correct later
tensor_pack
->
size
=
1
;
tensor_pack
->
size
=
1
;
...
@@ -67,10 +66,12 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
...
@@ -67,10 +66,12 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
softmax_aux_data
.
shape
.
data
[
3
]
=
1
;
// {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux_data
.
shape
.
data
[
3
]
=
1
;
// {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux_data
.
dtype
=
static_cast
<
NVTEDType
>
(
DType
::
kFloat32
);
softmax_aux_data
.
dtype
=
static_cast
<
NVTEDType
>
(
DType
::
kFloat32
);
int
size
=
2
;
// Start at 2 (we have softmax and rng_state at indices 0, 1)
// include bias if enabled
// include bias if enabled
if
(
bias_type
!=
NVTE_Bias_Type
::
NVTE_NO_BIAS
&&
bias_type
!=
NVTE_Bias_Type
::
NVTE_ALIBI
)
{
if
(
bias_type
!=
NVTE_Bias_Type
::
NVTE_NO_BIAS
&&
bias_type
!=
NVTE_Bias_Type
::
NVTE_ALIBI
)
{
tensor_pack
->
size
=
3
;
NVTETensor
&
bias_aux
=
tensor_pack
->
tensors
[
size
]
;
NVTETensor
&
bias_aux
=
tensor_pack
->
tensors
[
2
]
;
size
++
;
NVTEBasicTensor
bias_aux_data
;
NVTEBasicTensor
bias_aux_data
;
bias_aux_data
.
data_ptr
=
bias_buf
;
bias_aux_data
.
data_ptr
=
bias_buf
;
bias_aux_data
.
shape
.
ndim
=
4
;
bias_aux_data
.
shape
.
ndim
=
4
;
...
@@ -81,6 +82,24 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
...
@@ -81,6 +82,24 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
bias_aux_data
.
dtype
=
static_cast
<
NVTEDType
>
(
dtype
);
bias_aux_data
.
dtype
=
static_cast
<
NVTEDType
>
(
dtype
);
nvte_set_tensor_param
(
&
bias_aux
,
kNVTERowwiseData
,
&
bias_aux_data
);
nvte_set_tensor_param
(
&
bias_aux
,
kNVTERowwiseData
,
&
bias_aux_data
);
}
}
// include softmax_offset if provided
if
(
softmax_offset_buf
!=
nullptr
)
{
NVTETensor
&
softmax_offset_aux
=
tensor_pack
->
tensors
[
size
];
size
++
;
NVTEBasicTensor
softmax_offset_aux_data
;
softmax_offset_aux_data
.
data_ptr
=
softmax_offset_buf
;
softmax_offset_aux_data
.
shape
.
ndim
=
4
;
softmax_offset_aux_data
.
shape
.
data
[
0
]
=
1
;
softmax_offset_aux_data
.
shape
.
data
[
1
]
=
attn_heads
;
softmax_offset_aux_data
.
shape
.
data
[
2
]
=
1
;
softmax_offset_aux_data
.
shape
.
data
[
3
]
=
1
;
softmax_offset_aux_data
.
dtype
=
static_cast
<
NVTEDType
>
(
DType
::
kFloat32
);
nvte_set_tensor_param
(
&
softmax_offset_aux
,
kNVTERowwiseData
,
&
softmax_offset_aux_data
);
}
// Set final size
tensor_pack
->
size
=
size
;
}
}
nvte_set_tensor_param
(
&
softmax_aux
,
kNVTERowwiseData
,
&
softmax_aux_data
);
nvte_set_tensor_param
(
&
softmax_aux
,
kNVTERowwiseData
,
&
softmax_aux_data
);
}
}
...
@@ -98,14 +117,16 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
...
@@ -98,14 +117,16 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
const
size_t
bias_heads
,
const
size_t
q_max_seqlen
,
const
size_t
bias_heads
,
const
size_t
q_max_seqlen
,
const
size_t
kv_max_seqlen
,
DType
dtype
,
const
size_t
kv_max_seqlen
,
DType
dtype
,
NVTE_Fused_Attn_Backend
backend
,
void
*
softmax_buf
,
NVTE_Fused_Attn_Backend
backend
,
void
*
softmax_buf
,
void
*
rng_state_buf
,
void
*
bias_buf
)
{
void
*
rng_state_buf
,
void
*
bias_buf
,
void
*
softmax_offset_buf
=
nullptr
)
{
// Backward calls put everything into the tensor pack for every backend
// Backward calls put everything into the tensor pack for every backend
// so we set dummy bias_type and backend choices here to follow the correct code path
// so we set dummy bias_type and backend choices here to follow the correct code path
auto
dummy_bias_type
=
NVTE_Bias_Type
::
NVTE_POST_SCALE_BIAS
;
auto
dummy_bias_type
=
NVTE_Bias_Type
::
NVTE_POST_SCALE_BIAS
;
auto
dummy_backend
=
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
;
auto
dummy_backend
=
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
;
PrepareFusedAttnForwardAuxTensors
(
tensor_pack
,
input_batch
,
bias_batch
,
attn_heads
,
bias_heads
,
PrepareFusedAttnForwardAuxTensors
(
tensor_pack
,
input_batch
,
bias_batch
,
attn_heads
,
bias_heads
,
q_max_seqlen
,
kv_max_seqlen
,
dtype
,
dummy_bias_type
,
q_max_seqlen
,
kv_max_seqlen
,
dtype
,
dummy_bias_type
,
dummy_backend
,
softmax_buf
,
rng_state_buf
,
bias_buf
);
dummy_backend
,
softmax_buf
,
rng_state_buf
,
bias_buf
,
softmax_offset_buf
);
// correct softmax shape for max512 sequence length kernel
// correct softmax shape for max512 sequence length kernel
if
(
backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
...
@@ -121,8 +142,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
...
@@ -121,8 +142,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
DType
dtype
,
bool
is_training
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
auto
q_shape
=
std
::
vector
<
size_t
>
{
input_batch
*
q_max_seqlen
,
attn_heads
,
qk_head_dim
};
auto
q_shape
=
std
::
vector
<
size_t
>
{
input_batch
*
q_max_seqlen
,
attn_heads
,
qk_head_dim
};
auto
q_tensor
=
TensorWrapper
(
nullptr
,
q_shape
,
dtype
);
auto
q_tensor
=
TensorWrapper
(
nullptr
,
q_shape
,
dtype
);
auto
k_shape
=
std
::
vector
<
size_t
>
{
input_batch
*
kv_max_seqlen
,
num_gqa_groups
,
qk_head_dim
};
auto
k_shape
=
std
::
vector
<
size_t
>
{
input_batch
*
kv_max_seqlen
,
num_gqa_groups
,
qk_head_dim
};
...
@@ -141,7 +163,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
...
@@ -141,7 +163,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto
dummy_page_table_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kInt32
);
auto
dummy_page_table_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kInt32
);
auto
dummy_softmax_offset_tensor
=
auto
dummy_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
NVTETensorPack
aux_output_tensors
;
NVTETensorPack
aux_output_tensors
;
nvte_tensor_pack_create
(
&
aux_output_tensors
);
nvte_tensor_pack_create
(
&
aux_output_tensors
);
...
@@ -208,18 +229,21 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
...
@@ -208,18 +229,21 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto layout_group = nvte_get_qkv_layout_group(qkv_layout);
static
void
FusedAttnForwardImpl
(
static
void
FusedAttnForwardImpl
(
cudaStream_t
stream
,
void
*
q
,
void
*
k
,
void
*
v
,
void
*
bias
,
void
*
seed
,
void
*
q_cu_seqlens
,
cudaStream_t
stream
,
void
*
q
,
void
*
k
,
void
*
v
,
void
*
bias
,
void
*
softmax_offset
,
void
*
seed
,
void
*
kv_cu_seqlens
,
void
*
q_seq_offsets
,
void
*
k_seq_offsets
,
void
*
output
,
void
*
softmax_aux
,
void
*
q_cu_seqlens
,
void
*
kv_cu_seqlens
,
void
*
q_seq_offsets
,
void
*
k_seq_offsets
,
void
*
output
,
void
*
rng_state
,
void
*
workspace
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
void
*
softmax_aux
,
void
*
rng_state
,
void
*
workspace
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
size_t
max_segments_per_seq
,
size_t
wkspace_size
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
size_t
max_segments_per_seq
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
size_t
wkspace_size
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
DType
wkspace_dtype
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
bool
is_training
,
bool
deterministic
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
DType
dtype
,
DType
wkspace_dtype
,
bool
is_training
,
bool
deterministic
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
FUSED_ATTN_IMPL_COMMON_BLOCK
;
FUSED_ATTN_IMPL_COMMON_BLOCK
;
/* Input tensors */
/* Input tensors */
auto
bias_tensor
=
TensorWrapper
(
bias
,
bias_shape
,
dtype
);
auto
bias_tensor
=
TensorWrapper
(
bias
,
bias_shape
,
dtype
);
auto
softmax_offset_tensor
=
TensorWrapper
(
softmax_offset
,
std
::
vector
<
size_t
>
{
1
,
attn_heads
,
1
,
1
},
DType
::
kFloat32
);
if
(
is_ragged
)
{
if
(
is_ragged
)
{
auto
output_size
=
input_batch
*
q_max_seqlen
*
attn_heads
*
v_head_dim
;
auto
output_size
=
input_batch
*
q_max_seqlen
*
attn_heads
*
v_head_dim
;
...
@@ -238,10 +262,6 @@ static void FusedAttnForwardImpl(
...
@@ -238,10 +262,6 @@ static void FusedAttnForwardImpl(
/* Prepare RNG state */
/* Prepare RNG state */
auto
rng_state_tensor
=
TensorWrapper
(
rng_state
,
std
::
vector
<
size_t
>
{
2
},
DType
::
kInt64
);
auto
rng_state_tensor
=
TensorWrapper
(
rng_state
,
std
::
vector
<
size_t
>
{
2
},
DType
::
kInt64
);
auto
dummy_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
auto
backend
=
nvte_get_fused_attn_backend
(
auto
backend
=
nvte_get_fused_attn_backend
(
is_training
,
static_cast
<
NVTEDType
>
(
dtype
),
static_cast
<
NVTEDType
>
(
dtype
),
qkv_layout
,
is_training
,
static_cast
<
NVTEDType
>
(
dtype
),
static_cast
<
NVTEDType
>
(
dtype
),
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
dropout_probability
,
attn_heads
,
num_gqa_groups
,
bias_type
,
mask_type
,
softmax_type
,
dropout_probability
,
attn_heads
,
num_gqa_groups
,
...
@@ -254,7 +274,7 @@ static void FusedAttnForwardImpl(
...
@@ -254,7 +274,7 @@ static void FusedAttnForwardImpl(
nvte_tensor_pack_create
(
&
aux_output_tensors
);
nvte_tensor_pack_create
(
&
aux_output_tensors
);
PrepareFusedAttnForwardAuxTensors
(
&
aux_output_tensors
,
input_batch
,
bias_batch
,
attn_heads
,
PrepareFusedAttnForwardAuxTensors
(
&
aux_output_tensors
,
input_batch
,
bias_batch
,
attn_heads
,
bias_heads
,
q_max_seqlen
,
kv_max_seqlen
,
dtype
,
bias_type
,
bias_heads
,
q_max_seqlen
,
kv_max_seqlen
,
dtype
,
bias_type
,
backend
,
softmax_aux
);
backend
,
softmax_aux
,
softmax_offset
);
/* Call the underlying NVTE API */
/* Call the underlying NVTE API */
auto
dummy_page_table_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kInt32
);
auto
dummy_page_table_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kInt32
);
...
@@ -303,7 +323,7 @@ static void FusedAttnForwardImpl(
...
@@ -303,7 +323,7 @@ static void FusedAttnForwardImpl(
nvte_fused_attn_fwd
(
nvte_fused_attn_fwd
(
q_tensor
.
data
(),
k_tensor
.
data
(),
v_tensor
.
data
(),
bias_tensor
.
data
(),
q_tensor
.
data
(),
k_tensor
.
data
(),
v_tensor
.
data
(),
bias_tensor
.
data
(),
dummy_
softmax_offset_tensor
.
data
(),
s_tensor
.
data
(),
o_tensor
.
data
(),
&
aux_output_tensors
,
softmax_offset_tensor
.
data
(),
s_tensor
.
data
(),
o_tensor
.
data
(),
&
aux_output_tensors
,
q_cu_seqlens_tensor
.
data
(),
kv_cu_seqlens_tensor
.
data
(),
q_seq_offsets_tensor
.
data
(),
q_cu_seqlens_tensor
.
data
(),
kv_cu_seqlens_tensor
.
data
(),
q_seq_offsets_tensor
.
data
(),
k_seq_offsets_tensor
.
data
(),
dummy_page_table_tensor
.
data
(),
dummy_page_table_tensor
.
data
(),
k_seq_offsets_tensor
.
data
(),
dummy_page_table_tensor
.
data
(),
dummy_page_table_tensor
.
data
(),
rng_state_tensor
.
data
(),
q_max_seqlen
,
kv_max_seqlen
,
is_training
,
false
,
false
,
rng_state_tensor
.
data
(),
q_max_seqlen
,
kv_max_seqlen
,
is_training
,
false
,
false
,
...
@@ -332,6 +352,8 @@ static void FusedAttnForwardImpl(
...
@@ -332,6 +352,8 @@ static void FusedAttnForwardImpl(
static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type")); \
static_cast<NVTE_Bias_Type>(get_attr_value<int64_t>(attrs, "bias_type")); \
NVTE_Mask_Type mask_type = \
NVTE_Mask_Type mask_type = \
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type")); \
static_cast<NVTE_Mask_Type>(get_attr_value<int64_t>(attrs, "mask_type")); \
NVTE_Softmax_Type softmax_type = \
static_cast<NVTE_Softmax_Type>(get_attr_value<int64_t>(attrs, "softmax_type")); \
NVTE_QKV_Layout qkv_layout = \
NVTE_QKV_Layout qkv_layout = \
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout")); \
static_cast<NVTE_QKV_Layout>(get_attr_value<int64_t>(attrs, "qkv_layout")); \
bool is_training = get_attr_value<bool>(attrs, "is_training"); \
bool is_training = get_attr_value<bool>(attrs, "is_training"); \
...
@@ -342,7 +364,8 @@ static void FusedAttnForwardImpl(
...
@@ -342,7 +364,8 @@ static void FusedAttnForwardImpl(
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());
Error_Type
FusedAttnForwardFFI
(
cudaStream_t
stream
,
Buffer_Type
q_buf
,
Buffer_Type
k_buf
,
Error_Type
FusedAttnForwardFFI
(
cudaStream_t
stream
,
Buffer_Type
q_buf
,
Buffer_Type
k_buf
,
Buffer_Type
v_buf
,
Buffer_Type
bias_buf
,
Buffer_Type
seed_buf
,
Buffer_Type
v_buf
,
Buffer_Type
bias_buf
,
Buffer_Type
softmax_offset_buf
,
Buffer_Type
seed_buf
,
Buffer_Type
q_cu_seqlens_buf
,
Buffer_Type
kv_cu_seqlens_buf
,
Buffer_Type
q_cu_seqlens_buf
,
Buffer_Type
kv_cu_seqlens_buf
,
Buffer_Type
q_seq_offsets_buf
,
Buffer_Type
k_seq_offsets_buf
,
Buffer_Type
q_seq_offsets_buf
,
Buffer_Type
k_seq_offsets_buf
,
Variadic_Buffer_Type
_unused_args
,
Result_Type
output_buf
,
Variadic_Buffer_Type
_unused_args
,
Result_Type
output_buf
,
...
@@ -352,15 +375,15 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
...
@@ -352,15 +375,15 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
FusedAttnForwardImpl
(
FusedAttnForwardImpl
(
stream
,
q_buf
.
untyped_data
(),
k_buf
.
untyped_data
(),
v_buf
.
untyped_data
(),
stream
,
q_buf
.
untyped_data
(),
k_buf
.
untyped_data
(),
v_buf
.
untyped_data
(),
bias_buf
.
untyped_data
(),
seed_buf
.
untyped_data
(),
q_cu_seqlens_buf
.
untyped_data
(),
bias_buf
.
untyped_data
(),
softmax_offset_buf
.
untyped_data
(),
seed_buf
.
untyped_data
(),
kv_cu_seqlens_buf
.
untyped_data
(),
is_ragged
?
q_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
q_cu_seqlens_buf
.
untyped_data
(),
kv_cu_seqlens_buf
.
untyped_data
(),
is_ragged
?
q_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
is_ragged
?
k_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
output_buf
->
untyped_data
(),
is_ragged
?
k_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
output_buf
->
untyped_data
(),
softmax_aux_buf
->
untyped_data
(),
rng_state_buf
->
untyped_data
(),
workspace_buf
->
untyped_data
(),
softmax_aux_buf
->
untyped_data
(),
rng_state_buf
->
untyped_data
(),
workspace_buf
->
untyped_data
(),
input_batch
,
bias_batch
,
q_max_seqlen
,
kv_max_seqlen
,
attn_heads
,
num_gqa_groups
,
bias_heads
,
input_batch
,
bias_batch
,
q_max_seqlen
,
kv_max_seqlen
,
attn_heads
,
num_gqa_groups
,
bias_heads
,
qk_head_dim
,
v_head_dim
,
max_segments_per_seq
,
wkspace_size
,
scaling_factor
,
qk_head_dim
,
v_head_dim
,
max_segments_per_seq
,
wkspace_size
,
scaling_factor
,
dropout_probability
,
bias_type
,
mask_type
,
qkv_layout
,
dtype
,
wkspace_dtype
,
is_training
,
dropout_probability
,
bias_type
,
mask_type
,
softmax_type
,
qkv_layout
,
dtype
,
wkspace_dtype
,
deterministic
,
window_size_left
,
window_size_right
);
is_training
,
deterministic
,
window_size_left
,
window_size_right
);
return
ffi_with_cuda_error_check
();
return
ffi_with_cuda_error_check
();
}
}
...
@@ -371,6 +394,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
...
@@ -371,6 +394,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI,
.
Arg
<
Buffer_Type
>
()
// k
.
Arg
<
Buffer_Type
>
()
// k
.
Arg
<
Buffer_Type
>
()
// v
.
Arg
<
Buffer_Type
>
()
// v
.
Arg
<
Buffer_Type
>
()
// bias
.
Arg
<
Buffer_Type
>
()
// bias
.
Arg
<
Buffer_Type
>
()
// softmax_offset
.
Arg
<
Buffer_Type
>
()
// seed_buf
.
Arg
<
Buffer_Type
>
()
// seed_buf
.
Arg
<
Buffer_Type
>
()
// q_cu_seqlens
.
Arg
<
Buffer_Type
>
()
// q_cu_seqlens
.
Arg
<
Buffer_Type
>
()
// kv_cu_seqlens
.
Arg
<
Buffer_Type
>
()
// kv_cu_seqlens
...
@@ -388,9 +412,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
...
@@ -388,9 +412,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_
QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
NVTE_Mask_Type
mask_type
,
NVTE_
Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
DType
dtype
,
bool
is_training
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_right
)
{
int64_t
window_size_left
,
int64_t
window_size_right
)
{
auto
q_shape
=
std
::
vector
<
size_t
>
{
input_batch
*
q_max_seqlen
,
attn_heads
,
qk_head_dim
};
auto
q_shape
=
std
::
vector
<
size_t
>
{
input_batch
*
q_max_seqlen
,
attn_heads
,
qk_head_dim
};
auto
q_tensor
=
TensorWrapper
(
nullptr
,
q_shape
,
dtype
);
auto
q_tensor
=
TensorWrapper
(
nullptr
,
q_shape
,
dtype
);
auto
dq_tensor
=
TensorWrapper
(
nullptr
,
q_shape
,
dtype
);
auto
dq_tensor
=
TensorWrapper
(
nullptr
,
q_shape
,
dtype
);
...
@@ -425,9 +449,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
...
@@ -425,9 +449,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
// For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0
// For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0
min_num_segments
=
input_batch
*
max_segments_per_seq
;
min_num_segments
=
input_batch
*
max_segments_per_seq
;
}
}
auto
dummy_d_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
TensorWrapper
dummy_d_softmax_offset_tensor
;
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
if
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_OFF_BY_ONE_SOFTMAX
||
softmax_type
==
NVTE_Softmax_Type
::
NVTE_LEARNABLE_SOFTMAX
)
{
dummy_d_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
,
attn_heads
,
1
,
1
},
DType
::
kFloat32
);
}
for
(
auto
num_segments
=
min_num_segments
;
num_segments
<=
max_num_segments
;
++
num_segments
)
{
for
(
auto
num_segments
=
min_num_segments
;
num_segments
<=
max_num_segments
;
++
num_segments
)
{
// the last one is the largest which will be the returned workspace size
// the last one is the largest which will be the returned workspace size
auto
q_cu_seqlens_tensor
=
auto
q_cu_seqlens_tensor
=
...
@@ -457,15 +486,16 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
...
@@ -457,15 +486,16 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
}
}
static
void
FusedAttnBackwardImpl
(
static
void
FusedAttnBackwardImpl
(
cudaStream_t
stream
,
void
*
q
,
void
*
k
,
void
*
v
,
void
*
bias
,
void
*
softmax_aux
,
void
*
rng_state
,
cudaStream_t
stream
,
void
*
q
,
void
*
k
,
void
*
v
,
void
*
bias
,
void
*
softmax_offset
,
void
*
output
,
void
*
doutput
,
void
*
q_cu_seqlens
,
void
*
kv_cu_seqlens
,
void
*
q_seq_offsets
,
void
*
softmax_aux
,
void
*
rng_state
,
void
*
output
,
void
*
doutput
,
void
*
q_cu_seqlens
,
void
*
k_seq_offsets
,
void
*
dq
,
void
*
dk
,
void
*
dv
,
void
*
dbias
,
void
*
workspace
,
void
*
kv_cu_seqlens
,
void
*
q_seq_offsets
,
void
*
k_seq_offsets
,
void
*
dq
,
void
*
dk
,
void
*
dv
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
void
*
dbias
,
void
*
dsoftmax_offset
,
void
*
workspace
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
v_head_dim
,
size_t
max_segments_per_seq
,
size_t
wkspace_size
,
float
scaling_factor
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
size_t
max_segments_per_seq
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
size_t
wkspace_size
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
DType
wkspace_dtype
,
bool
is_training
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
bool
deterministic
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
DType
dtype
,
DType
wkspace_dtype
,
bool
is_training
,
bool
deterministic
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
FUSED_ATTN_IMPL_COMMON_BLOCK
;
FUSED_ATTN_IMPL_COMMON_BLOCK
;
/* Input tensors */
/* Input tensors */
...
@@ -476,9 +506,13 @@ static void FusedAttnBackwardImpl(
...
@@ -476,9 +506,13 @@ static void FusedAttnBackwardImpl(
/* Output tensors */
/* Output tensors */
auto
s_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
dtype
);
// not used in F16
auto
s_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
dtype
);
// not used in F16
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
bias_shape
,
dtype
);
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
bias_shape
,
dtype
);
auto
dummy_d_softmax_offset_tensor
=
TensorWrapper
(
nullptr
,
std
::
vector
<
size_t
>
{
1
},
DType
::
kFloat32
);
TensorWrapper
dsoftmax_offset_tensor
;
NVTE_Softmax_Type
softmax_type
=
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
;
if
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_OFF_BY_ONE_SOFTMAX
||
softmax_type
==
NVTE_Softmax_Type
::
NVTE_LEARNABLE_SOFTMAX
)
{
dsoftmax_offset_tensor
=
TensorWrapper
(
dsoftmax_offset
,
std
::
vector
<
size_t
>
{
1
,
attn_heads
,
1
,
1
},
DType
::
kFloat32
);
}
/* Auxiliary tensors (propagated from the forward pass) */
/* Auxiliary tensors (propagated from the forward pass) */
NVTETensorPack
aux_input_tensors
;
NVTETensorPack
aux_input_tensors
;
...
@@ -490,7 +524,7 @@ static void FusedAttnBackwardImpl(
...
@@ -490,7 +524,7 @@ static void FusedAttnBackwardImpl(
false
,
false
);
false
,
false
);
PrepareFusedAttnBackwardAuxTensors
(
&
aux_input_tensors
,
input_batch
,
bias_batch
,
attn_heads
,
PrepareFusedAttnBackwardAuxTensors
(
&
aux_input_tensors
,
input_batch
,
bias_batch
,
attn_heads
,
bias_heads
,
q_max_seqlen
,
kv_max_seqlen
,
dtype
,
backend
,
bias_heads
,
q_max_seqlen
,
kv_max_seqlen
,
dtype
,
backend
,
softmax_aux
,
rng_state
,
bias
);
softmax_aux
,
rng_state
,
bias
,
softmax_offset
);
/* Call the underly NVTE API */
/* Call the underly NVTE API */
// Prepare Q, K, V pointers and shapes based on layout
// Prepare Q, K, V pointers and shapes based on layout
...
@@ -564,7 +598,7 @@ static void FusedAttnBackwardImpl(
...
@@ -564,7 +598,7 @@ static void FusedAttnBackwardImpl(
s_tensor
.
data
(),
// not used for F16
s_tensor
.
data
(),
// not used for F16
s_tensor
.
data
(),
// not used for F16
s_tensor
.
data
(),
// not used for F16
&
aux_input_tensors
,
dq_tensor
.
data
(),
dk_tensor
.
data
(),
dv_tensor
.
data
(),
dbias_tensor
.
data
(),
&
aux_input_tensors
,
dq_tensor
.
data
(),
dk_tensor
.
data
(),
dv_tensor
.
data
(),
dbias_tensor
.
data
(),
d
ummy_d_
softmax_offset_tensor
.
data
(),
q_cu_seqlens_tensor
.
data
(),
kv_cu_seqlens_tensor
.
data
(),
dsoftmax_offset_tensor
.
data
(),
q_cu_seqlens_tensor
.
data
(),
kv_cu_seqlens_tensor
.
data
(),
q_seq_offsets_tensor
.
data
(),
k_seq_offsets_tensor
.
data
(),
q_max_seqlen
,
kv_max_seqlen
,
q_seq_offsets_tensor
.
data
(),
k_seq_offsets_tensor
.
data
(),
q_max_seqlen
,
kv_max_seqlen
,
scaling_factor
,
dropout_probability
,
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
scaling_factor
,
dropout_probability
,
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
deterministic
,
false
,
workspace_tensor
.
data
(),
stream
);
window_size_left
,
window_size_right
,
deterministic
,
false
,
workspace_tensor
.
data
(),
stream
);
...
@@ -574,26 +608,29 @@ static void FusedAttnBackwardImpl(
...
@@ -574,26 +608,29 @@ static void FusedAttnBackwardImpl(
Error_Type
FusedAttnBackwardFFI
(
cudaStream_t
stream
,
Buffer_Type
q_buf
,
Buffer_Type
k_buf
,
Error_Type
FusedAttnBackwardFFI
(
cudaStream_t
stream
,
Buffer_Type
q_buf
,
Buffer_Type
k_buf
,
Buffer_Type
v_buf
,
Buffer_Type
bias_buf
,
Buffer_Type
v_buf
,
Buffer_Type
bias_buf
,
Buffer_Type
softmax_aux_buf
,
Buffer_Type
rng_state_buf
,
Buffer_Type
softmax_offset_buf
,
Buffer_Type
softmax_aux_buf
,
Buffer_Type
output_buf
,
Buffer_Type
doutput_buf
,
Buffer_Type
rng_state_buf
,
Buffer_Type
output_buf
,
Buffer_Type
q_cu_seqlens_buf
,
Buffer_Type
kv_cu_seqlens_buf
,
Buffer_Type
doutput_buf
,
Buffer_Type
q_cu_seqlens_buf
,
Buffer_Type
q_seq_offsets_buf
,
Buffer_Type
k_seq_offsets_buf
,
Buffer_Type
kv_cu_seqlens_buf
,
Buffer_Type
q_seq_offsets_buf
,
Variadic_Buffer_Type
_unused_args
,
Result_Type
dq_buf
,
Buffer_Type
k_seq_offsets_buf
,
Variadic_Buffer_Type
_unused_args
,
Result_Type
dk_buf
,
Result_Type
dv_buf
,
Result_Type
dbias_buf
,
Result_Type
dq_buf
,
Result_Type
dk_buf
,
Result_Type
dv_buf
,
Result_Type
dbias_buf
,
Result_Type
dsoftmax_offset_buf
,
Result_Type
workspace_buf
,
Dictionary
attrs
)
{
Result_Type
workspace_buf
,
Dictionary
attrs
)
{
FUSED_ATTN_FFI_GET_ATTRS
;
FUSED_ATTN_FFI_GET_ATTRS
;
FusedAttnBackwardImpl
(
FusedAttnBackwardImpl
(
stream
,
q_buf
.
untyped_data
(),
k_buf
.
untyped_data
(),
v_buf
.
untyped_data
(),
stream
,
q_buf
.
untyped_data
(),
k_buf
.
untyped_data
(),
v_buf
.
untyped_data
(),
bias_buf
.
untyped_data
(),
softmax_aux_buf
.
untyped_data
(),
rng_state_buf
.
untyped_data
(),
bias_buf
.
untyped_data
(),
softmax_offset_buf
.
untyped_data
(),
softmax_aux_buf
.
untyped_data
(),
output_buf
.
untyped_data
(),
doutput_buf
.
untyped_data
(),
q_cu_seqlens_buf
.
untyped_data
(),
rng_state_buf
.
untyped_data
(),
output_buf
.
untyped_data
(),
doutput_buf
.
untyped_data
(),
kv_cu_seqlens_buf
.
untyped_data
(),
is_ragged
?
q_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
q_cu_seqlens_buf
.
untyped_data
(),
kv_cu_seqlens_buf
.
untyped_data
(),
is_ragged
?
q_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
is_ragged
?
k_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
dq_buf
->
untyped_data
(),
is_ragged
?
k_seq_offsets_buf
.
untyped_data
()
:
nullptr
,
dq_buf
->
untyped_data
(),
dk_buf
->
untyped_data
(),
dv_buf
->
untyped_data
(),
dbias_buf
->
untyped_data
(),
dk_buf
->
untyped_data
(),
dv_buf
->
untyped_data
(),
dbias_buf
->
untyped_data
(),
workspace_buf
->
untyped_data
(),
input_batch
,
bias_batch
,
q_max_seqlen
,
kv_max_seqlen
,
dsoftmax_offset_buf
->
untyped_data
(),
workspace_buf
->
untyped_data
(),
input_batch
,
bias_batch
,
attn_heads
,
num_gqa_groups
,
bias_heads
,
qk_head_dim
,
v_head_dim
,
max_segments_per_seq
,
q_max_seqlen
,
kv_max_seqlen
,
attn_heads
,
num_gqa_groups
,
bias_heads
,
qk_head_dim
,
v_head_dim
,
wkspace_size
,
scaling_factor
,
dropout_probability
,
bias_type
,
mask_type
,
qkv_layout
,
dtype
,
max_segments_per_seq
,
wkspace_size
,
scaling_factor
,
dropout_probability
,
bias_type
,
mask_type
,
wkspace_dtype
,
is_training
,
deterministic
,
window_size_left
,
window_size_right
);
softmax_type
,
qkv_layout
,
dtype
,
wkspace_dtype
,
is_training
,
deterministic
,
window_size_left
,
window_size_right
);
return
ffi_with_cuda_error_check
();
return
ffi_with_cuda_error_check
();
}
}
...
@@ -605,6 +642,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
...
@@ -605,6 +642,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
.
Arg
<
Buffer_Type
>
()
// k
.
Arg
<
Buffer_Type
>
()
// k
.
Arg
<
Buffer_Type
>
()
// v
.
Arg
<
Buffer_Type
>
()
// v
.
Arg
<
Buffer_Type
>
()
// bias
.
Arg
<
Buffer_Type
>
()
// bias
.
Arg
<
Buffer_Type
>
()
// softmax_offset
.
Arg
<
Buffer_Type
>
()
// softmax_aux
.
Arg
<
Buffer_Type
>
()
// softmax_aux
.
Arg
<
Buffer_Type
>
()
// rng_state
.
Arg
<
Buffer_Type
>
()
// rng_state
.
Arg
<
Buffer_Type
>
()
// output
.
Arg
<
Buffer_Type
>
()
// output
...
@@ -618,6 +656,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
...
@@ -618,6 +656,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI,
.
Ret
<
Buffer_Type
>
()
// dk
.
Ret
<
Buffer_Type
>
()
// dk
.
Ret
<
Buffer_Type
>
()
// dv
.
Ret
<
Buffer_Type
>
()
// dv
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// dsoftmax_offset
.
Ret
<
Buffer_Type
>
()
// workspace
.
Ret
<
Buffer_Type
>
()
// workspace
.
Attrs
(),
.
Attrs
(),
FFI_CudaGraph_Traits
);
FFI_CudaGraph_Traits
);
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment