Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
970620a5
Commit
970620a5
authored
Dec 27, 2025
by
wenjh
Browse files
merge nv_release_v2.10 to release_v2.10
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
c1a1c04e
769ed778
Changes
135
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
648 additions
and
240 deletions
+648
-240
tests/pytorch/distributed/test_comm_gemm_overlap.py
tests/pytorch/distributed/test_comm_gemm_overlap.py
+6
-0
tests/pytorch/distributed/test_sanity.py
tests/pytorch/distributed/test_sanity.py
+73
-13
tests/pytorch/test_fusible_ops.py
tests/pytorch/test_fusible_ops.py
+5
-5
tests/pytorch/test_numerics.py
tests/pytorch/test_numerics.py
+0
-7
tests/pytorch/test_sanity.py
tests/pytorch/test_sanity.py
+1
-3
tests/pytorch/utils.py
tests/pytorch/utils.py
+10
-0
transformer_engine/common/fused_attn/kv_cache.cu
transformer_engine/common/fused_attn/kv_cache.cu
+1
-1
transformer_engine/common/include/transformer_engine/fused_attn.h
...mer_engine/common/include/transformer_engine/fused_attn.h
+1
-1
transformer_engine/common/recipe/__init__.py
transformer_engine/common/recipe/__init__.py
+20
-18
transformer_engine/common/transformer_engine.cpp
transformer_engine/common/transformer_engine.cpp
+12
-7
transformer_engine/jax/attention.py
transformer_engine/jax/attention.py
+75
-18
transformer_engine/jax/cpp_extensions/activation.py
transformer_engine/jax/cpp_extensions/activation.py
+1
-1
transformer_engine/jax/cpp_extensions/attention.py
transformer_engine/jax/cpp_extensions/attention.py
+236
-56
transformer_engine/jax/cpp_extensions/gemm.py
transformer_engine/jax/cpp_extensions/gemm.py
+1
-1
transformer_engine/jax/cpp_extensions/misc.py
transformer_engine/jax/cpp_extensions/misc.py
+1
-1
transformer_engine/jax/cpp_extensions/normalization.py
transformer_engine/jax/cpp_extensions/normalization.py
+1
-1
transformer_engine/jax/cpp_extensions/quantization.py
transformer_engine/jax/cpp_extensions/quantization.py
+19
-17
transformer_engine/jax/cpp_extensions/softmax.py
transformer_engine/jax/cpp_extensions/softmax.py
+65
-9
transformer_engine/jax/csrc/extensions.h
transformer_engine/jax/csrc/extensions.h
+12
-12
transformer_engine/jax/csrc/extensions/attention.cpp
transformer_engine/jax/csrc/extensions/attention.cpp
+108
-69
No files found.
tests/pytorch/distributed/test_comm_gemm_overlap.py
View file @
970620a5
...
@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
...
@@ -127,12 +127,18 @@ def _run_layer_with_overlap(
os
.
environ
[
"PYTORCH_JIT"
]
=
"0"
os
.
environ
[
"PYTORCH_JIT"
]
=
"0"
os
.
environ
[
"NVTE_TORCH_COMPILE"
]
=
"0"
os
.
environ
[
"NVTE_TORCH_COMPILE"
]
=
"0"
os
.
environ
[
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
]
=
"0"
os
.
environ
[
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
]
=
"0"
if
te
.
get_device_compute_capability
()
<=
(
8
,
0
):
# We've experienced numerical discrepancies in Flash Attention
# backward when running with Userbuffers on A100s. This does
# not show up in more recent GPUs.
os
.
environ
[
"NVTE_FLASH_ATTN"
]
=
"0"
result
=
subprocess
.
run
(
test_cmd
,
env
=
os
.
environ
,
capture_output
=
True
,
check
=
False
)
result
=
subprocess
.
run
(
test_cmd
,
env
=
os
.
environ
,
capture_output
=
True
,
check
=
False
)
os
.
unsetenv
(
"PYTORCH_JIT"
)
os
.
unsetenv
(
"PYTORCH_JIT"
)
os
.
unsetenv
(
"NVTE_TORCH_COMPILE"
)
os
.
unsetenv
(
"NVTE_TORCH_COMPILE"
)
os
.
unsetenv
(
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
)
os
.
unsetenv
(
"NVTE_ALLOW_NONDETERMINISTIC_ALGO"
)
os
.
unsetenv
(
"NVTE_FLASH_ATTN"
)
if
(
if
(
result
.
returncode
!=
0
result
.
returncode
!=
0
...
...
tests/pytorch/distributed/test_sanity.py
View file @
970620a5
...
@@ -7,7 +7,7 @@ import sys
...
@@ -7,7 +7,7 @@ import sys
import
pytest
import
pytest
import
torch
import
torch
import
transformer_engine
import
transformer_engine
from
transformer_engine.pytorch
import
DotProductAttention
,
TransformerLayer
,
Linear
from
transformer_engine.pytorch
import
DotProductAttention
,
TransformerLayer
,
Linear
,
GroupedLinear
_current_file
=
pathlib
.
Path
(
__file__
).
resolve
()
_current_file
=
pathlib
.
Path
(
__file__
).
resolve
()
sys
.
path
.
append
(
str
(
_current_file
.
parent
.
parent
))
sys
.
path
.
append
(
str
(
_current_file
.
parent
.
parent
))
...
@@ -19,7 +19,9 @@ model_configs = {
...
@@ -19,7 +19,9 @@ model_configs = {
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"small"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"small"
])
@
pytest
.
mark
.
parametrize
(
"module"
,
[
"TransformerLayer"
,
"DotProductAttention"
,
"Linear"
])
@
pytest
.
mark
.
parametrize
(
"module"
,
[
"TransformerLayer"
,
"DotProductAttention"
,
"Linear"
,
"GroupedLinear"
]
)
def
test_current_device
(
model
,
module
):
def
test_current_device
(
model
,
module
):
"""Test cases where current device is different from tensor device"""
"""Test cases where current device is different from tensor device"""
...
@@ -42,7 +44,29 @@ def test_current_device(model, module):
...
@@ -42,7 +44,29 @@ def test_current_device(model, module):
self_attn_mask_type
=
"padding"
,
self_attn_mask_type
=
"padding"
,
device
=
f
"cuda:
{
tensor_device
}
"
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
)
num_tokens
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
,
(
1
,)).
item
()
seqlens_q
=
torch
.
randint
(
1
,
config
.
max_seqlen_q
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_q
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
seqlens_q
,
dim
=
0
)
seqlens_kv
=
torch
.
randint
(
1
,
config
.
max_seqlen_kv
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_kv
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_kv
[
1
:]
=
torch
.
cumsum
(
seqlens_kv
,
dim
=
0
)
num_tokens
=
cu_seqlens_q
[
-
1
]
args
=
[
args
=
[
torch
.
randn
(
torch
.
randn
(
(
num_tokens
,
config
.
hidden_size
),
(
num_tokens
,
config
.
hidden_size
),
...
@@ -51,37 +75,55 @@ def test_current_device(model, module):
...
@@ -51,37 +75,55 @@ def test_current_device(model, module):
requires_grad
=
True
,
requires_grad
=
True
,
)
)
]
]
cu_seqlens_q
,
cu_seqlens_kv
=
[
torch
.
Tensor
([
0
,
2
,
3
]).
to
(
dtype
=
torch
.
int32
,
device
=
tensor_device
)
for
_
in
range
(
2
)
]
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
if
module
==
"DotProductAttention"
:
el
if
module
==
"DotProductAttention"
:
model
=
DotProductAttention
(
model
=
DotProductAttention
(
config
.
num_heads
,
config
.
head_dim_qk
,
qkv_format
=
"thd"
,
attn_mask_type
=
"padding"
config
.
num_heads
,
config
.
head_dim_qk
,
qkv_format
=
"thd"
,
attn_mask_type
=
"padding"
)
)
num_tokens
=
torch
.
randint
(
0
,
config
.
max_seqlen_q
,
(
1
,)).
item
()
seqlens_q
=
torch
.
randint
(
1
,
config
.
max_seqlen_q
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_q
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_q
[
1
:]
=
torch
.
cumsum
(
seqlens_q
,
dim
=
0
)
seqlens_kv
=
torch
.
randint
(
1
,
config
.
max_seqlen_kv
,
[
config
.
batch_size
],
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
cu_seqlens_kv
=
torch
.
zeros
(
config
.
batch_size
+
1
,
dtype
=
torch
.
int32
,
device
=
f
"cuda:
{
tensor_device
}
"
)
cu_seqlens_kv
[
1
:]
=
torch
.
cumsum
(
seqlens_kv
,
dim
=
0
)
num_tokens
=
cu_seqlens_q
[
-
1
]
args
=
[
args
=
[
torch
.
randn
(
torch
.
randn
(
num_tokens
,
num_tokens
,
config
.
num_heads
,
config
.
num_heads
,
config
.
head_dim_qk
,
config
.
head_dim_qk
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
tensor_device
,
device
=
f
"cuda:
{
tensor_device
}
"
,
requires_grad
=
True
,
requires_grad
=
True
,
)
)
for
_
in
range
(
3
)
for
_
in
range
(
3
)
]
]
cu_seqlens_q
,
cu_seqlens_kv
=
[
torch
.
Tensor
([
0
,
2
,
3
]).
to
(
dtype
=
torch
.
int32
,
device
=
tensor_device
)
for
_
in
range
(
2
)
]
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_q"
]
=
cu_seqlens_q
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"cu_seqlens_kv"
]
=
cu_seqlens_kv
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_q"
]
=
config
.
max_seqlen_q
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
kwargs
[
"max_seqlen_kv"
]
=
config
.
max_seqlen_kv
bwd_args
=
[
torch
.
randn
(
num_tokens
,
config
.
hidden_size
,
dtype
=
dtype
,
device
=
tensor_device
)]
bwd_args
=
[
torch
.
randn
(
num_tokens
,
config
.
hidden_size
,
dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
)
]
elif
module
==
"Linear"
:
elif
module
==
"Linear"
:
model
=
Linear
(
model
=
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
...
@@ -97,6 +139,24 @@ def test_current_device(model, module):
...
@@ -97,6 +139,24 @@ def test_current_device(model, module):
requires_grad
=
True
,
requires_grad
=
True
,
)
)
]
]
elif
module
==
"GroupedLinear"
:
num_gemms
=
4
model
=
GroupedLinear
(
num_gemms
,
config
.
hidden_size
,
4
*
config
.
hidden_size
,
params_dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
,
)
args
=
[
torch
.
randn
(
(
config
.
max_seqlen_q
*
config
.
batch_size
*
(
num_gemms
-
1
),
config
.
hidden_size
),
dtype
=
dtype
,
device
=
f
"cuda:
{
tensor_device
}
"
,
requires_grad
=
True
,
),
[
0
]
+
[
config
.
max_seqlen_q
*
config
.
batch_size
]
*
(
num_gemms
-
1
),
# Empty first split.
]
current_device_before
=
torch
.
cuda
.
current_device
()
current_device_before
=
torch
.
cuda
.
current_device
()
out
=
model
(
*
args
,
**
kwargs
)
out
=
model
(
*
args
,
**
kwargs
)
...
...
tests/pytorch/test_fusible_ops.py
View file @
970620a5
...
@@ -913,15 +913,15 @@ class TestBasicOps:
...
@@ -913,15 +913,15 @@ class TestBasicOps:
dtype
=
dtype
,
dtype
=
dtype
,
accumulate_into_main_grad
=
accumulate_into_main_grad
,
accumulate_into_main_grad
=
accumulate_into_main_grad
,
)
)
with
torch
.
no_grad
():
op
.
weight
.
copy_
(
w_test
)
del
w_test
op
.
weight
.
main_grad
=
torch
.
full_like
(
op
.
weight
,
0.5
,
dtype
=
torch
.
float32
)
forward
=
te_ops
.
Sequential
(
forward
=
te_ops
.
Sequential
(
te_ops
.
Quantize
(
forward
=
quantized_input
,
backward
=
quantized_grad_input
),
te_ops
.
Quantize
(
forward
=
quantized_input
,
backward
=
quantized_grad_input
),
op
,
op
,
te_ops
.
Quantize
(
forward
=
quantized_output
,
backward
=
quantized_grad_output
),
te_ops
.
Quantize
(
forward
=
quantized_output
,
backward
=
quantized_grad_output
),
)
)
with
torch
.
no_grad
():
op
.
weight
.
copy_
(
w_test
)
del
w_test
op
.
weight
.
main_grad
=
torch
.
full_like
(
op
.
weight
,
0.5
,
dtype
=
torch
.
float32
)
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
with
te
.
autocast
(
enabled
=
quantized_compute
,
recipe
=
recipe
):
y_test
=
forward
(
x_test
)
y_test
=
forward
(
x_test
)
y_test
.
backward
(
dy_test
)
y_test
.
backward
(
dy_test
)
...
...
tests/pytorch/test_numerics.py
View file @
970620a5
...
@@ -46,7 +46,6 @@ from transformer_engine.pytorch import (
...
@@ -46,7 +46,6 @@ from transformer_engine.pytorch import (
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.module.base
import
get_multi_stream_cublas_workspace
,
get_workspace
from
transformer_engine.common
import
recipe
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
utils
import
ModelConfig
,
reset_rng_states
from
utils
import
ModelConfig
,
reset_rng_states
...
@@ -2757,7 +2756,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
...
@@ -2757,7 +2756,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
general_gemm
(
general_gemm
(
A
[
i
],
A
[
i
],
B
[
i
],
B
[
i
],
get_workspace
(),
dtype
,
dtype
,
grad
=
grad
,
grad
=
grad
,
accumulate
=
accumulate
,
accumulate
=
accumulate
,
...
@@ -2772,7 +2770,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
...
@@ -2772,7 +2770,6 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
B
,
B
,
out
,
out
,
dtype
,
dtype
,
get_multi_stream_cublas_workspace
(),
m_splits
=
m_splits
,
m_splits
=
m_splits
,
grad
=
grad
,
grad
=
grad
,
accumulate
=
accumulate
,
accumulate
=
accumulate
,
...
@@ -2832,7 +2829,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
...
@@ -2832,7 +2829,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
quantized_out
,
*
_
=
general_gemm
(
quantized_out
,
*
_
=
general_gemm
(
weight_fp8
,
weight_fp8
,
inp_fp8
,
inp_fp8
,
get_workspace
(),
outp_type
,
outp_type
,
quantization_params
=
out_quantizer
,
quantization_params
=
out_quantizer
,
bias
=
None
,
bias
=
None
,
...
@@ -2842,7 +2838,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
...
@@ -2842,7 +2838,6 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua
out
,
*
_
=
general_gemm
(
out
,
*
_
=
general_gemm
(
weight_fp8
,
weight_fp8
,
inp_fp8
,
inp_fp8
,
get_workspace
(),
outp_type
,
outp_type
,
quantization_params
=
None
,
quantization_params
=
None
,
bias
=
None
,
bias
=
None
,
...
@@ -2918,7 +2913,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
...
@@ -2918,7 +2913,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
general_gemm
(
general_gemm
(
A_fp8
[
i
],
A_fp8
[
i
],
B_fp8
[
i
],
B_fp8
[
i
],
get_workspace
(),
dtype
,
dtype
,
out
=
out_ref
[
i
],
out
=
out_ref
[
i
],
accumulate
=
accumulate
,
accumulate
=
accumulate
,
...
@@ -2928,7 +2922,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
...
@@ -2928,7 +2922,6 @@ def test_fp8_grouped_gemm(shape, accumulate):
B_fp8
,
B_fp8
,
out
,
out
,
dtype
,
dtype
,
get_multi_stream_cublas_workspace
(),
m_splits
=
m_splits
,
m_splits
=
m_splits
,
accumulate
=
accumulate
,
accumulate
=
accumulate
,
)
)
...
...
tests/pytorch/test_sanity.py
View file @
970620a5
...
@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
...
@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
from
transformer_engine.common
import
recipe
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
from
transformer_engine.pytorch.module.base
import
get_workspace
from
transformer_engine.pytorch.tensor.utils
import
replace_raw_data
from
transformer_engine.pytorch.tensor.utils
import
replace_raw_data
from
utils
import
ModelConfig
from
utils
import
ModelConfig
...
@@ -961,7 +960,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
...
@@ -961,7 +960,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
inp
=
torch
.
reshape
(
scratchpad
[
offset
:
-
offset
],
(
N
,
N
))
inp
=
torch
.
reshape
(
scratchpad
[
offset
:
-
offset
],
(
N
,
N
))
weight
=
torch
.
reshape
(
scratchpad
[
offset
*
2
:],
(
N
,
N
))
weight
=
torch
.
reshape
(
scratchpad
[
offset
*
2
:],
(
N
,
N
))
_
=
general_gemm
(
A
=
weight
,
B
=
inp
,
workspace
=
get_workspace
()
)
_
=
general_gemm
(
A
=
weight
,
B
=
inp
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -985,7 +984,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
...
@@ -985,7 +984,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
general_gemm
(
general_gemm
(
weight_fp8
,
weight_fp8
,
inp_fp8
,
inp_fp8
,
get_workspace
(),
outp_type
,
outp_type
,
bias
=
None
,
bias
=
None
,
use_split_accumulator
=
False
,
use_split_accumulator
=
False
,
...
...
tests/pytorch/utils.py
View file @
970620a5
...
@@ -8,6 +8,7 @@ import logging
...
@@ -8,6 +8,7 @@ import logging
import
os
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Optional
,
Tuple
,
Dict
,
Any
,
List
from
typing
import
Optional
,
Tuple
,
Dict
,
Any
,
List
from
packaging.version
import
Version
as
PkgVersion
import
torch
import
torch
...
@@ -210,6 +211,7 @@ class ModelConfig:
...
@@ -210,6 +211,7 @@ class ModelConfig:
max_ctx_len
:
int
=
None
,
max_ctx_len
:
int
=
None
,
num_layers
:
int
=
1
,
num_layers
:
int
=
1
,
eps
:
float
=
1e-5
,
eps
:
float
=
1e-5
,
num_splits
=
1
,
):
):
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
max_seqlen_q
=
max_seqlen_q
self
.
max_seqlen_q
=
max_seqlen_q
...
@@ -239,6 +241,7 @@ class ModelConfig:
...
@@ -239,6 +241,7 @@ class ModelConfig:
self
.
max_ctx_len
=
max_ctx_len
self
.
max_ctx_len
=
max_ctx_len
self
.
num_layers
=
num_layers
self
.
num_layers
=
num_layers
self
.
eps
=
eps
self
.
eps
=
eps
self
.
num_splits
=
num_splits
@
contextmanager
@
contextmanager
...
@@ -321,6 +324,9 @@ def get_available_attention_backends(
...
@@ -321,6 +324,9 @@ def get_available_attention_backends(
inference_params
=
inference_params
,
inference_params
=
inference_params
,
softmax_type
=
config
.
softmax_type
,
softmax_type
=
config
.
softmax_type
,
return_max_logit
=
config
.
return_max_logit
,
return_max_logit
=
config
.
return_max_logit
,
# allow all backends to pass so they can be used for testing;
# check for FA3 availability later
num_splits
=
1
,
)
)
(
(
use_flash_attention
,
use_flash_attention
,
...
@@ -330,6 +336,10 @@ def get_available_attention_backends(
...
@@ -330,6 +336,10 @@ def get_available_attention_backends(
use_unfused_attention
,
use_unfused_attention
,
available_backends
,
available_backends
,
)
=
get_attention_backend
(
attention_params
)
)
=
get_attention_backend
(
attention_params
)
# Check if FA3 is an available backend when num_splits != 1
if
available_backends
[
0
]:
if
config
.
num_splits
!=
1
and
not
flash_attention_backend
>
PkgVersion
(
"3.0.0b"
):
available_backends
[
0
]
=
False
# Set attention.py _attention_backends var using return value
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
# from get_attention_backend()
_attention_backends
[
"use_flash_attention"
]
=
use_flash_attention
_attention_backends
[
"use_flash_attention"
]
=
use_flash_attention
...
...
transformer_engine/common/fused_attn/kv_cache.cu
View file @
970620a5
...
@@ -278,7 +278,7 @@ void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, in
...
@@ -278,7 +278,7 @@ void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, in
/***************************************************************************************************
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are
in
shape [b + 1]; cu_cached_lens include the added lens
* 2. cu_new_lens and cu_cached_lens are
of
shape [b + 1]; cu_cached_lens include the added lens
* in current step
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
...
...
transformer_engine/common/include/transformer_engine/fused_attn.h
View file @
970620a5
...
@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
...
@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter
in
shape [H].
* where alpha is a learnable parameter
of
shape [H].
*/
*/
enum
NVTE_Softmax_Type
{
enum
NVTE_Softmax_Type
{
/*! Vanilla softmax */
/*! Vanilla softmax */
...
...
transformer_engine/common/recipe/__init__.py
View file @
970620a5
...
@@ -50,7 +50,7 @@ class MMParams:
...
@@ -50,7 +50,7 @@ class MMParams:
Parameters
Parameters
----------
----------
use_split_accumulator : bool, default =
`
True
`
use_split_accumulator : bool, default = True
Use FP8 fast accumulation on Hopper or Ada. For more details,
Use FP8 fast accumulation on Hopper or Ada. For more details,
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
"""
"""
...
@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
...
@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
recipe: DelayedScaling) -> Tensor
recipe: DelayedScaling) -> Tensor
where `Tensor` is a framework tensor type.
where `Tensor` is a framework tensor type.
reduce_amax: bool, default =
`
True
`
reduce_amax: bool, default = True
By default, if `torch.distributed` is initialized, the `amax` value for FP8
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
call). This keeps the amaxes and scaling factors synced across the given
call). This keeps the amaxes and scaling factors synced across the given
...
@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
...
@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
GPU maintains local amaxes and scaling factors. To ensure results are
GPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all
numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors.
ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default =
`
False
`
fp8_dpa: bool, default = False
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA currently is only supported in the
back to higher precision as outputs. FP8 DPA currently is only supported in the
`FusedAttention` backend.
`FusedAttention` backend.
fp8_mha: bool, default =
`
False
`
fp8_mha: bool, default = False
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
operations mentioned above at the DPA boundaries. Currently only standard MHA modules
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
...
@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
...
@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
----------
----------
fp4_format : {Format.E2M1}, default = Format.E2M1
fp4_format : {Format.E2M1}, default = Format.E2M1
FP4 data type.
FP4 data type.
disable_rht : bool, default =
`
False
`
disable_rht : bool, default = False
If set to `True`, random Hadamard transforms are not applied to any tensor.
If set to `True`, random Hadamard transforms are not applied to any tensor.
disable_stochastic_rounding : bool, default =
`
False
`
disable_stochastic_rounding : bool, default = False
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default =
`
False
`
disable_2d_quantization : bool, default = False
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
"""
"""
...
@@ -494,13 +494,15 @@ class CustomRecipe(Recipe):
...
@@ -494,13 +494,15 @@ class CustomRecipe(Recipe):
qfactory : Callable
qfactory : Callable
Factory callable that returns a quantizer instance for a
Factory callable that returns a quantizer instance for a
given semantic tensor role.
given semantic tensor role.
The callable is typically invoked as:
The callable is typically invoked as::
qfactory(
qfactory(
role: str,
role: str,
)
)
Where `role` is one of the following strings for e.g. te.Linear
Where `role` is one of the following strings for e.g. te.Linear
(stable public contract):
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
- backward: "linear_grad_output", "linear_grad_input"
"""
"""
...
...
transformer_engine/common/transformer_engine.cpp
View file @
970620a5
...
@@ -736,12 +736,17 @@ int nvte_is_non_tn_fp8_gemm_supported() {
...
@@ -736,12 +736,17 @@ int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
#if USE_ROCM
return
true
;
return
true
;
#else
#else
int
deviceComputeCapability
=
int
num_devices
=
transformer_engine
::
cuda
::
num_devices
();
transformer_engine
::
cuda
::
sm_arch
(
transformer_engine
::
cuda
::
current_device
());
static
std
::
vector
<
int
>
cache
(
num_devices
,
-
1
);
static
std
::
vector
<
std
::
once_flag
>
flags
(
num_devices
);
int
device_id
=
transformer_engine
::
cuda
::
current_device
();
std
::
call_once
(
flags
[
device_id
],
[
&
]()
{
int
deviceComputeCapability
=
transformer_engine
::
cuda
::
sm_arch
(
device_id
);
// Note: this is temporary restriction and should be lifted in the future.
// Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
// (remove the note once it's done.)
return
(
deviceComputeCapability
>=
100
&&
deviceComputeCapability
<
120
)
||
cache
[
device_id
]
=
(
deviceComputeCapability
>=
100
&&
deviceComputeCapability
<
120
)
||
deviceComputeCapability
>=
130
;
deviceComputeCapability
>=
130
;
});
return
cache
[
device_id
];
#endif
#endif
}
}
transformer_engine/jax/attention.py
View file @
970620a5
...
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
...
@@ -18,6 +18,7 @@ from transformer_engine_jax import NVTE_Mask_Type
from
transformer_engine_jax
import
NVTE_QKV_Layout
from
transformer_engine_jax
import
NVTE_QKV_Layout
from
transformer_engine_jax
import
NVTE_QKV_Format
from
transformer_engine_jax
import
NVTE_QKV_Format
from
transformer_engine_jax
import
nvte_get_qkv_format
from
transformer_engine_jax
import
nvte_get_qkv_format
from
transformer_engine_jax
import
NVTE_Softmax_Type
from
.
import
cpp_extensions
as
tex
from
.
import
cpp_extensions
as
tex
...
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
...
@@ -74,6 +75,35 @@ class AttnMaskType(Enum):
]
]
class
AttnSoftmaxType
(
Enum
):
"""
VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)),
LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
where alpha is a learnable parameter in shape [H].
"""
VANILLA_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_VANILLA_SOFTMAX
OFF_BY_ONE_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_OFF_BY_ONE_SOFTMAX
LEARNABLE_SOFTMAX
=
NVTE_Softmax_Type
.
NVTE_LEARNABLE_SOFTMAX
@
classmethod
def
from_str
(
cls
,
softmax_type
:
str
)
->
"AttnSoftmaxType"
:
"""Convert string to AttnSoftmaxType: 'vanilla', 'off_by_one', or 'learnable'."""
softmax_type_map
=
{
"vanilla"
:
cls
.
VANILLA_SOFTMAX
,
"off_by_one"
:
cls
.
OFF_BY_ONE_SOFTMAX
,
"learnable"
:
cls
.
LEARNABLE_SOFTMAX
,
}
result
=
softmax_type_map
.
get
(
softmax_type
)
if
result
is
None
:
raise
ValueError
(
f
"Unknown softmax_type:
{
softmax_type
}
. "
"Valid options: 'vanilla', 'off_by_one', 'learnable'"
)
return
result
class
QKVFormat
(
Enum
):
class
QKVFormat
(
Enum
):
"""
"""
SBHD: q,k,v memory layout with [s, b, ..., h, d]
SBHD: q,k,v memory layout with [s, b, ..., h, d]
...
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
...
@@ -301,6 +331,7 @@ def is_fused_attn_kernel_available(
qkv_layout
,
qkv_layout
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
softmax_type
,
dropout_probability
,
dropout_probability
,
q_num_heads
,
q_num_heads
,
kv_num_heads
,
kv_num_heads
,
...
@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
...
@@ -313,6 +344,7 @@ def is_fused_attn_kernel_available(
"""
"""
To check whether the fused attention kernel is supported
To check whether the fused attention kernel is supported
"""
"""
window_size_tuple
=
(
-
1
,
-
1
)
if
window_size
is
None
else
window_size
def
make_helper
(
attn_mask_type
):
def
make_helper
(
attn_mask_type
):
return
tex
.
FusedAttnHelper
(
return
tex
.
FusedAttnHelper
(
...
@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
...
@@ -322,6 +354,7 @@ def is_fused_attn_kernel_available(
qkv_layout
,
qkv_layout
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
softmax_type
,
dropout_probability
,
dropout_probability
,
q_num_heads
,
q_num_heads
,
kv_num_heads
,
kv_num_heads
,
...
@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
...
@@ -329,7 +362,7 @@ def is_fused_attn_kernel_available(
kv_max_seqlen
,
kv_max_seqlen
,
head_dim_qk
,
head_dim_qk
,
head_dim_v
,
head_dim_v
,
(
-
1
,
-
1
)
if
window_size
is
None
else
window_size
,
window_siz
e_tupl
e
,
)
)
return
make_helper
(
attn_mask_type
).
is_fused_attn_kernel_available
()
return
make_helper
(
attn_mask_type
).
is_fused_attn_kernel_available
()
...
@@ -497,6 +530,11 @@ def _segment_ids_pos_to_seqlens_offsets(
...
@@ -497,6 +530,11 @@ def _segment_ids_pos_to_seqlens_offsets(
#
#
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
# examine only O(Q+KV) elements.
# examine only O(Q+KV) elements.
# For seqlens and seqoffsets calculations, the intermediate(temp) attn_mask creation
# using the segment ids and pos along with mask type (causal or brcm) is sufficient.
# It does not need to involve SW for this mask's creation
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
# TODO(KshitijLakhani): Try exercising the fast path for BRCM as well
if
(
attn_mask_type
.
is_causal
()
and
window_size
is
None
)
or
(
if
(
attn_mask_type
.
is_causal
()
and
window_size
is
None
)
or
(
window_size
==
(
-
1
,
-
1
)
and
not
attn_mask_type
.
is_bottom_right
()
window_size
==
(
-
1
,
-
1
)
and
not
attn_mask_type
.
is_bottom_right
()
...
@@ -558,21 +596,6 @@ def _segment_ids_pos_to_seqlens_offsets(
...
@@ -558,21 +596,6 @@ def _segment_ids_pos_to_seqlens_offsets(
)
)
attn_mask
=
jnp
.
logical_and
(
segment_mask
,
causal_mask
)
attn_mask
=
jnp
.
logical_and
(
segment_mask
,
causal_mask
)
# TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets
swa_mask
=
(
make_swa_mask
(
segment_pos_q
,
segment_pos_kv
,
window_size
,
dtype
=
jnp
.
bool
,
segment_ids_q
=
segment_ids_q
,
segment_ids_kv
=
segment_ids_kv
,
)
if
attn_mask_type
.
is_bottom_right
()
else
make_swa_mask
(
segment_pos_q
,
segment_pos_kv
,
window_size
,
dtype
=
jnp
.
bool
)
)
attn_mask
=
jnp
.
logical_and
(
attn_mask
,
swa_mask
)
attn_mask_with_id
=
jnp
.
where
(
attn_mask
,
segment_mask_with_id
,
0
)
attn_mask_with_id
=
jnp
.
where
(
attn_mask
,
segment_mask_with_id
,
0
)
q_seqlen
,
q_offset
,
kv_seqlen
,
kv_offset
=
_mask_to_seqlens_offset
(
q_seqlen
,
q_offset
,
kv_seqlen
,
kv_offset
=
_mask_to_seqlens_offset
(
attn_mask_with_id
,
max_segments_per_seq
attn_mask_with_id
,
max_segments_per_seq
...
@@ -786,6 +809,7 @@ def _legacy_fused_attn(
...
@@ -786,6 +809,7 @@ def _legacy_fused_attn(
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -793,6 +817,7 @@ def _legacy_fused_attn(
...
@@ -793,6 +817,7 @@ def _legacy_fused_attn(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
):
"""
"""
Perform non-THD (non-packed) cuDNN fused attention.
Perform non-THD (non-packed) cuDNN fused attention.
...
@@ -815,6 +840,7 @@ def _legacy_fused_attn(
...
@@ -815,6 +840,7 @@ def _legacy_fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -863,10 +889,12 @@ def _legacy_fused_attn(
...
@@ -863,10 +889,12 @@ def _legacy_fused_attn(
output
=
_fused_attn
(
output
=
_fused_attn
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
SequenceDescriptor
.
from_seqlens
((
q_seq_lens
,
kv_seq_lens
)),
SequenceDescriptor
.
from_seqlens
((
q_seq_lens
,
kv_seq_lens
)),
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -900,6 +928,7 @@ def fused_attn_thd(
...
@@ -900,6 +928,7 @@ def fused_attn_thd(
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_strategy
:
CPStrategy
=
CPStrategy
.
DEFAULT
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
):
"""
"""
Deprecated THD fused attn, please use fused_attn with SequenceDescriptor
Deprecated THD fused attn, please use fused_attn with SequenceDescriptor
...
@@ -937,6 +966,7 @@ def fused_attn_thd(
...
@@ -937,6 +966,7 @@ def fused_attn_thd(
output
=
_fused_attn
(
output
=
_fused_attn
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
SequenceDescriptor
.
from_seqlens_and_offsets
(
SequenceDescriptor
.
from_seqlens_and_offsets
(
(
q_seq_lens
,
kv_seq_lens
),
(
q_seq_offsets
,
kv_seq_offsets
)
(
q_seq_lens
,
kv_seq_lens
),
(
q_seq_offsets
,
kv_seq_offsets
)
),
),
...
@@ -945,6 +975,7 @@ def fused_attn_thd(
...
@@ -945,6 +975,7 @@ def fused_attn_thd(
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
softmax_type
=
AttnSoftmaxType
.
VANILLA_SOFTMAX
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
max_segments_per_seq
=
max_segments_per_seq
,
max_segments_per_seq
=
max_segments_per_seq
,
...
@@ -957,15 +988,17 @@ def fused_attn_thd(
...
@@ -957,15 +988,17 @@ def fused_attn_thd(
return
output
return
output
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
))
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
))
def
_fused_attn
(
def
_fused_attn
(
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
qkv
:
Tuple
[
jnp
.
ndarray
,
...],
bias
:
Optional
[
jnp
.
ndarray
],
bias
:
Optional
[
jnp
.
ndarray
],
softmax_offset
:
Optional
[
jnp
.
ndarray
],
sequence_descriptor
:
SequenceDescriptor
,
sequence_descriptor
:
SequenceDescriptor
,
seed
:
Optional
[
jnp
.
ndarray
],
seed
:
Optional
[
jnp
.
ndarray
],
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -979,11 +1012,13 @@ def _fused_attn(
...
@@ -979,11 +1012,13 @@ def _fused_attn(
output
,
_
=
_fused_attn_fwd_rule
(
output
,
_
=
_fused_attn_fwd_rule
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
qkv_layout
,
qkv_layout
,
softmax_type
,
scaling_factor
,
scaling_factor
,
dropout_probability
,
dropout_probability
,
is_training
,
is_training
,
...
@@ -1000,11 +1035,13 @@ def _fused_attn(
...
@@ -1000,11 +1035,13 @@ def _fused_attn(
def
_fused_attn_fwd_rule
(
def
_fused_attn_fwd_rule
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
qkv_layout
,
qkv_layout
,
softmax_type
,
scaling_factor
,
scaling_factor
,
dropout_probability
,
dropout_probability
,
is_training
,
is_training
,
...
@@ -1018,10 +1055,12 @@ def _fused_attn_fwd_rule(
...
@@ -1018,10 +1055,12 @@ def _fused_attn_fwd_rule(
output
,
softmax_aux
,
rng_state
=
tex
.
fused_attn_fwd
(
output
,
softmax_aux
,
rng_state
=
tex
.
fused_attn_fwd
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -1041,6 +1080,7 @@ def _fused_attn_fwd_rule(
...
@@ -1041,6 +1080,7 @@ def _fused_attn_fwd_rule(
sequence_descriptor
,
sequence_descriptor
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
softmax_offset
,
output
,
output
,
)
)
...
@@ -1049,6 +1089,7 @@ def _fused_attn_bwd_rule(
...
@@ -1049,6 +1089,7 @@ def _fused_attn_bwd_rule(
attn_bias_type
,
attn_bias_type
,
attn_mask_type
,
attn_mask_type
,
qkv_layout
,
qkv_layout
,
softmax_type
,
scaling_factor
,
scaling_factor
,
dropout_probability
,
dropout_probability
,
is_training
,
is_training
,
...
@@ -1068,11 +1109,13 @@ def _fused_attn_bwd_rule(
...
@@ -1068,11 +1109,13 @@ def _fused_attn_bwd_rule(
sequence_descriptor
,
sequence_descriptor
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
softmax_offset
,
output
,
output
,
)
=
ctx
)
=
ctx
grad_qkv
,
grad_bias
=
tex
.
fused_attn_bwd
(
grad_qkv
,
grad_bias
,
grad_softmax_offset
=
tex
.
fused_attn_bwd
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
softmax_aux
,
softmax_aux
,
rng_state
,
rng_state
,
output
,
output
,
...
@@ -1080,6 +1123,7 @@ def _fused_attn_bwd_rule(
...
@@ -1080,6 +1123,7 @@ def _fused_attn_bwd_rule(
sequence_descriptor
,
sequence_descriptor
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -1092,9 +1136,12 @@ def _fused_attn_bwd_rule(
...
@@ -1092,9 +1136,12 @@ def _fused_attn_bwd_rule(
)
)
if
attn_bias_type
==
AttnBiasType
.
NO_BIAS
:
if
attn_bias_type
==
AttnBiasType
.
NO_BIAS
:
grad_bias
=
None
grad_bias
=
None
if
softmax_type
!=
AttnSoftmaxType
.
LEARNABLE_SOFTMAX
:
grad_softmax_offset
=
None
return
(
return
(
grad_qkv
,
grad_qkv
,
grad_bias
,
grad_bias
,
grad_softmax_offset
,
None
,
None
,
None
,
None
,
)
)
...
@@ -1111,6 +1158,7 @@ def fused_attn(
...
@@ -1111,6 +1158,7 @@ def fused_attn(
attn_bias_type
:
AttnBiasType
,
attn_bias_type
:
AttnBiasType
,
attn_mask_type
:
AttnMaskType
,
attn_mask_type
:
AttnMaskType
,
qkv_layout
:
QKVLayout
,
qkv_layout
:
QKVLayout
,
softmax_type
:
AttnSoftmaxType
,
scaling_factor
:
float
,
scaling_factor
:
float
,
dropout_probability
:
float
,
dropout_probability
:
float
,
is_training
:
bool
,
is_training
:
bool
,
...
@@ -1120,6 +1168,7 @@ def fused_attn(
...
@@ -1120,6 +1168,7 @@ def fused_attn(
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_causal_load_balanced
:
bool
=
False
,
context_parallel_axis
:
str
=
""
,
context_parallel_axis
:
str
=
""
,
context_checkpoint_name
:
str
=
"context"
,
context_checkpoint_name
:
str
=
"context"
,
softmax_offset
:
Optional
[
jnp
.
ndarray
]
=
None
,
):
):
"""
"""
Perform cuDNN fused attention.
Perform cuDNN fused attention.
...
@@ -1139,6 +1188,7 @@ def fused_attn(
...
@@ -1139,6 +1188,7 @@ def fused_attn(
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
attn_mask_type (AttnMaskType): Type of attention mask.
softmax_type (AttnSoftmaxType): Type of attention softmax.
qkv_layout (QKVLayout): Layout of the QKV tensors.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
dropout_probability (float): Dropout probability to apply during attention.
...
@@ -1153,6 +1203,9 @@ def fused_attn(
...
@@ -1153,6 +1203,9 @@ def fused_attn(
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
context_parallel_axis (str): The name of the context parallel axis.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
[1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
If provided, this parameter will receive gradients during backpropagation.
Returns:
Returns:
(jnp.ndarray): The output tensor from the fused attention.
(jnp.ndarray): The output tensor from the fused attention.
...
@@ -1200,6 +1253,7 @@ def fused_attn(
...
@@ -1200,6 +1253,7 @@ def fused_attn(
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
softmax_type
=
softmax_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
...
@@ -1208,15 +1262,18 @@ def fused_attn(
...
@@ -1208,15 +1262,18 @@ def fused_attn(
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_strategy
=
context_parallel_strategy
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_causal_load_balanced
=
context_parallel_causal_load_balanced
,
context_parallel_axis
=
context_parallel_axis
,
context_parallel_axis
=
context_parallel_axis
,
softmax_offset
=
softmax_offset
,
)
)
output
=
_fused_attn
(
output
=
_fused_attn
(
qkv
,
qkv
,
bias
,
bias
,
softmax_offset
,
sequence_descriptor
,
sequence_descriptor
,
seed
,
seed
,
attn_bias_type
=
attn_bias_type
,
attn_bias_type
=
attn_bias_type
,
attn_mask_type
=
attn_mask_type
,
attn_mask_type
=
attn_mask_type
,
qkv_layout
=
qkv_layout
,
qkv_layout
=
qkv_layout
,
softmax_type
=
softmax_type
,
scaling_factor
=
scaling_factor
,
scaling_factor
=
scaling_factor
,
dropout_probability
=
dropout_probability
,
dropout_probability
=
dropout_probability
,
is_training
=
is_training
,
is_training
=
is_training
,
...
...
transformer_engine/jax/cpp_extensions/activation.py
View file @
970620a5
...
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
...
@@ -32,9 +32,9 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
(
from
..quantize
import
(
Quantizer
,
Quantizer
,
QuantizeLayout
,
DelayedScaleQuantizer
,
DelayedScaleQuantizer
,
ScalingMode
,
ScalingMode
,
QuantizeLayout
,
)
)
...
...
transformer_engine/jax/cpp_extensions/attention.py
View file @
970620a5
This diff is collapsed.
Click to expand it.
transformer_engine/jax/cpp_extensions/gemm.py
View file @
970620a5
...
@@ -39,12 +39,12 @@ from ..quantize import (
...
@@ -39,12 +39,12 @@ from ..quantize import (
Quantizer
,
Quantizer
,
GroupedQuantizer
,
GroupedQuantizer
,
QuantizerSet
,
QuantizerSet
,
QuantizeLayout
,
noop_quantizer_set
,
noop_quantizer_set
,
is_fp8_gemm_with_all_layouts_supported
,
is_fp8_gemm_with_all_layouts_supported
,
apply_padding_to_scale_inv
,
apply_padding_to_scale_inv
,
get_quantize_config_with_recipe
,
get_quantize_config_with_recipe
,
get_global_quantize_recipe
,
get_global_quantize_recipe
,
QuantizeLayout
,
)
)
from
.misc
import
get_padded_spec
,
is_all_reduce_in_float32
from
.misc
import
get_padded_spec
,
is_all_reduce_in_float32
from
..sharding
import
(
from
..sharding
import
(
...
...
transformer_engine/jax/cpp_extensions/misc.py
View file @
970620a5
...
@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
...
@@ -116,7 +116,7 @@ def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
transpose. Note, transpose_axis should be greater than static_axis_boundary
transpose. Note, transpose_axis should be greater than static_axis_boundary
examples:
examples:
X
in
shape (dim0, dim1, dim2, dim3, dim4)
X
of
shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis == 2
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
Xt = (dim2, dim3, dim4, dim0, dim1)
...
...
transformer_engine/jax/cpp_extensions/normalization.py
View file @
970620a5
...
@@ -35,9 +35,9 @@ from ..sharding import (
...
@@ -35,9 +35,9 @@ from ..sharding import (
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
,
NoScaleTensor
from
..quantize
import
(
from
..quantize
import
(
Quantizer
,
Quantizer
,
QuantizeLayout
,
DelayedScaleQuantizer
,
DelayedScaleQuantizer
,
ScalingMode
,
ScalingMode
,
QuantizeLayout
,
)
)
...
...
transformer_engine/jax/cpp_extensions/quantization.py
View file @
970620a5
...
@@ -40,11 +40,11 @@ from ..quantize import (
...
@@ -40,11 +40,11 @@ from ..quantize import (
GroupedScaledTensor1x
,
GroupedScaledTensor1x
,
Quantizer
,
Quantizer
,
GroupedQuantizer
,
GroupedQuantizer
,
QuantizeLayout
,
ScalingMode
,
ScalingMode
,
compute_scale_from_amax
,
compute_scale_from_amax
,
NoScaleTensor
,
NoScaleTensor
,
get_rht_matrix
,
get_rht_matrix
,
QuantizeLayout
,
)
)
...
@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
...
@@ -497,6 +497,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
amax_spec
=
get_padded_spec
(
arg_infos
[
2
])
amax_spec
=
get_padded_spec
(
arg_infos
[
2
])
sr_rng_state_spec
=
get_padded_spec
(
arg_infos
[
3
])
out_sharding
=
NamedSharding
(
out_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
*
x_spec
),
PartitionSpec
(
*
x_spec
),
...
@@ -551,9 +552,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
...
@@ -551,9 +552,12 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
)
arg_shardings
=
list
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
=
list
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
if
len
(
sr_rng_state_spec
)
>
1
:
# sr_rng_state shape [n_devices, state_per_device]
sr_rng_state_spec
=
(
*
tuple
(
x
for
x
in
x_spec
if
x
is
not
None
),
None
)
arg_shardings
[
3
]
=
NamedSharding
(
arg_shardings
[
3
]
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
tuple
(
x
for
x
in
x_spec
if
x
is
not
None
),
None
),
PartitionSpec
(
*
sr_rng_state_spec
),
desc
=
"BaseDBiasQuantizePrimitive.sr_rng_state"
,
desc
=
"BaseDBiasQuantizePrimitive.sr_rng_state"
,
)
)
arg_shardings
=
tuple
(
arg_shardings
)
arg_shardings
=
tuple
(
arg_shardings
)
...
@@ -654,9 +658,11 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
...
@@ -654,9 +658,11 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias
=
input_spec
[
flatten_axis
:]
if
is_dbias
else
(
prefix
+
"_dbias"
,)
dbias
=
input_spec
[
flatten_axis
:]
if
is_dbias
else
(
prefix
+
"_dbias"
,)
amax
=
(
BATCHING
+
prefix
+
"_amax"
,)
amax
=
(
BATCHING
+
prefix
+
"_amax"
,)
scale
=
(
BATCHING
+
prefix
+
"_scale"
,)
scale
=
(
BATCHING
+
prefix
+
"_scale"
,)
sr_rng_state
=
(
BATCHING
+
prefix
+
"_sr_rng_state"
,)
if
value_types
[
3
].
shape
!=
[
0
]:
sr_rng_state
=
(
sr_rng_state
=
(
BATCHING
+
prefix
+
"_sr_rng_state_
partition_axi
s"
,
BATCHING
+
prefix
+
"_sr_rng_state_
device
s"
,
BATCHING
+
prefix
+
"sr_rng_state_data
_axis
"
,
prefix
+
"sr_rng_state_data"
,
)
)
post_rht_amax
=
(
BATCHING
+
prefix
+
"_post_rht_amax"
,)
post_rht_amax
=
(
BATCHING
+
prefix
+
"_post_rht_amax"
,)
...
@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
...
@@ -849,7 +855,7 @@ def _quantize_dbias_impl(
if
force_1x_quantization
:
if
force_1x_quantization
:
q_layout
=
QuantizeLayout
.
ROWWISE
q_layout
=
QuantizeLayout
.
ROWWISE
sr_rng_state
=
None
sr_rng_state
=
jnp
.
empty
((
0
,),
jnp
.
uint32
)
if
quantizer
.
scaling_mode
.
is_nvfp4_scaling
:
if
quantizer
.
scaling_mode
.
is_nvfp4_scaling
:
# Only NVFP4 scaling modes support stochastic rounding
# Only NVFP4 scaling modes support stochastic rounding
if
quantizer
.
stochastic_rounding_rng_state
is
not
None
:
if
quantizer
.
stochastic_rounding_rng_state
is
not
None
:
...
@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
...
@@ -866,11 +872,7 @@ def _quantize_dbias_impl(
x
.
data
,
x
.
data
,
scale
,
scale
,
amax
,
amax
,
(
sr_rng_state
,
sr_rng_state
if
sr_rng_state
is
not
None
else
jnp
.
empty
((
get_num_devices_in_mesh
(),
1
),
jnp
.
uint32
)
),
post_rht_amax
if
post_rht_amax
is
not
None
else
jnp
.
zeros
((
1
,),
jnp
.
float32
),
post_rht_amax
if
post_rht_amax
is
not
None
else
jnp
.
zeros
((
1
,),
jnp
.
float32
),
rht_matrix
,
rht_matrix
,
out_dtype
=
quantizer
.
q_dtype
,
out_dtype
=
quantizer
.
q_dtype
,
...
@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
...
@@ -880,7 +882,7 @@ def _quantize_dbias_impl(
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_dtype
=
quantizer
.
get_scale_dtype
(),
is_dbias
=
is_dbias
if
not
quantizer
.
scaling_mode
.
is_nvfp4_scaling
else
False
,
is_dbias
=
is_dbias
if
not
quantizer
.
scaling_mode
.
is_nvfp4_scaling
else
False
,
is_outer
=
True
,
is_outer
=
True
,
stochastic_rounding
=
sr_rng_state
is
not
None
,
stochastic_rounding
=
sr_rng_state
.
size
!=
0
,
use_rht
=
use_rht
,
use_rht
=
use_rht
,
)
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
...
...
transformer_engine/jax/cpp_extensions/softmax.py
View file @
970620a5
...
@@ -11,10 +11,11 @@ import jax
...
@@ -11,10 +11,11 @@ import jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax
import
dtypes
,
ffi
from
jax
import
dtypes
,
ffi
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
.attention
import
AttnSoftmaxType
from
.base
import
BasePrimitive
,
register_primitive
from
.base
import
BasePrimitive
,
register_primitive
from
.misc
import
get_padded_spec
,
check_valid_batch_dims
from
.misc
import
get_padded_spec
,
check_valid_batch_dims
from
..softmax
import
SoftmaxType
from
..softmax
import
Softmax
Fusion
Type
__all__
=
[
__all__
=
[
...
@@ -32,7 +33,8 @@ __all__ = [
...
@@ -32,7 +33,8 @@ __all__ = [
def
is_softmax_kernel_available
(
def
is_softmax_kernel_available
(
softmax_type
:
SoftmaxType
,
softmax_fusion_type
:
SoftmaxFusionType
,
softmax_type
:
AttnSoftmaxType
,
batch
:
int
,
batch
:
int
,
heads
:
int
,
heads
:
int
,
q_seqlen
:
int
,
q_seqlen
:
int
,
...
@@ -40,15 +42,18 @@ def is_softmax_kernel_available(
...
@@ -40,15 +42,18 @@ def is_softmax_kernel_available(
dtype
:
jnp
.
dtype
,
dtype
:
jnp
.
dtype
,
):
):
"""check softmax available"""
"""check softmax available"""
if
softmax_type
is
SoftmaxType
.
SCALED
:
if
softmax_type
!=
AttnSoftmaxType
.
VANILLA_SOFTMAX
:
return
False
if
softmax_fusion_type
is
SoftmaxFusionType
.
SCALED
:
return
ScaledSoftmaxFwdPrimitive
.
is_kernel_available
(
return
ScaledSoftmaxFwdPrimitive
.
is_kernel_available
(
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
)
)
if
softmax_type
is
SoftmaxType
.
SCALED_MASKED
:
if
softmax_
fusion_
type
is
Softmax
Fusion
Type
.
SCALED_MASKED
:
return
ScaledMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
return
ScaledMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
)
)
if
softmax_type
is
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
:
if
softmax_
fusion_
type
is
Softmax
Fusion
Type
.
SCALED_UPPER_TRIANG_MASKED
:
return
ScaledUpperTriangMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
return
ScaledUpperTriangMaskedSoftmaxFwdPrimitive
.
is_kernel_available
(
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
batch
,
heads
,
q_seqlen
,
k_seqlen
,
dtype
)
)
...
@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
...
@@ -792,26 +797,77 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive
(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
)
register_primitive
(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
)
def
jax_scaled_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
,
softmax_offset
:
jnp
.
ndarray
|
float
|
None
=
None
):
"""
"""
JAX based implementation of scaled softmax
JAX based implementation of scaled softmax
"""
"""
if
softmax_offset
is
not
None
:
return
jax_general_softmax
(
scale_factor
*
logits
,
offset
=
softmax_offset
)
return
jax
.
nn
.
softmax
(
scale_factor
*
logits
)
return
jax
.
nn
.
softmax
(
scale_factor
*
logits
)
def
jax_scaled_masked_softmax
(
logits
:
jnp
.
ndarray
,
mask
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_masked_softmax
(
logits
:
jnp
.
ndarray
,
mask
:
jnp
.
ndarray
,
scale_factor
:
float
,
softmax_offset
:
jnp
.
ndarray
|
float
|
None
=
None
,
):
"""
"""
JAX based implementation of scaled and masked softmax
JAX based implementation of scaled and masked softmax
"""
"""
if
softmax_offset
is
not
None
:
return
jax_general_softmax
(
logits
*
scale_factor
,
offset
=
softmax_offset
,
where
=
mask
!=
1
)
return
jax
.
nn
.
softmax
(
logits
*
scale_factor
,
where
=
mask
!=
1
)
return
jax
.
nn
.
softmax
(
logits
*
scale_factor
,
where
=
mask
!=
1
)
def jax_scaled_upper_triang_masked_softmax(
    logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
):
    """
    Reference JAX implementation of scaled softmax with a causal
    (upper-triangular) mask.

    Entries strictly above the main diagonal are masked out; the optional
    ``softmax_offset`` is forwarded to ``jax_scaled_masked_softmax``.
    """
    # mask == 1 marks the strictly-upper-triangular (future) positions
    causal_mask = jnp.triu(jnp.ones_like(logits), k=1)
    return jax_scaled_masked_softmax(logits, causal_mask, scale_factor, softmax_offset)
def
jax_general_softmax
(
x
:
jnp
.
ndarray
,
axis
:
int
=
-
1
,
where
:
jnp
.
ndarray
|
None
=
None
,
initial
:
jnp
.
ndarray
=
-
jnp
.
inf
,
offset
:
jnp
.
ndarray
|
float
|
None
=
None
,
)
->
jnp
.
ndarray
:
"""
JAX based implementation of general softmax with optional masking and offset.
"""
# Compute max of x
x_max
=
jnp
.
max
(
x
,
axis
,
where
=
where
,
initial
=
initial
,
keepdims
=
True
)
if
offset
is
not
None
:
# Cast offset to x.dtype to prevent type promotion
if
isinstance
(
offset
,
(
int
,
float
)):
offset
=
jnp
.
array
(
offset
,
dtype
=
x
.
dtype
)
else
:
offset
=
offset
.
astype
(
x
.
dtype
)
# Include offset in max: x_max = max(x_max, offset)
# This is equivalent to computing max over [x..., offset]
x_max
=
jnp
.
maximum
(
x_max
,
offset
)
x_safe
=
x
if
where
is
None
else
jnp
.
where
(
where
,
x
,
initial
)
unnormalized
=
jnp
.
exp
(
x_safe
-
x_max
)
denominator
=
jnp
.
sum
(
unnormalized
,
axis
,
where
=
where
,
keepdims
=
True
)
if
offset
is
not
None
:
# Add exp(offset - x_max) to denominator
denominator
=
denominator
+
jnp
.
exp
(
offset
-
x_max
)
result
=
unnormalized
/
denominator
if
where
is
not
None
:
result
=
jnp
.
where
(
where
,
result
,
0
)
return
result
def
scaled_softmax_fwd
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
)
->
jnp
.
ndarray
:
def
scaled_softmax_fwd
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
)
->
jnp
.
ndarray
:
...
...
transformer_engine/jax/csrc/extensions.h
View file @
970620a5
...
@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
...
@@ -108,28 +108,28 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
FusedAttnBackwardHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
FusedAttnBackwardHandler
);
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_Fused_Attn_Backend
GetFusedAttnBackend
(
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
bool
is_training
,
DType
q_dtype
,
DType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
mask_type
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
size_t
q_num_heads
,
size_t
kv_num_heads
,
float
dropout_probability
,
size_t
q_attn_heads
,
size_t
kv_attn_heads
,
size_t
q_max_seqlen
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_left
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
int64_t
window_size_right
);
int64_t
window_size_left
,
int64_t
window_size_right
);
pybind11
::
tuple
GetFusedAttnForwardWorkspaceSizes
(
pybind11
::
tuple
GetFusedAttnForwardWorkspaceSizes
(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
);
DType
dtype
,
bool
is_training
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
int64_t
window_size_right
);
pybind11
::
tuple
GetFusedAttnBackwardWorkspaceSizes
(
pybind11
::
tuple
GetFusedAttnBackwardWorkspaceSizes
(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
qk_head_dim
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
size_t
v_head_dim
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_
QKV_Layout
qkv_layout
,
DType
dtype
,
bool
is_training
,
NVTE_Mask_Type
mask_type
,
NVTE_
Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_left
,
DType
dtype
,
bool
is_training
,
bool
deterministic
,
size_t
max_segments_per_seq
,
int64_t
window_size_right
);
int64_t
window_size_left
,
int64_t
window_size_right
);
// GEMM
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
GemmHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
GemmHandler
);
...
...
transformer_engine/jax/csrc/extensions/attention.cpp
View file @
970620a5
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment