OpenDAS / TransformerEngine / Commits / 0a5016b1

Commit 0a5016b1, authored Dec 03, 2025 by wenjh

    Merge nv release_v2.9

    Signed-off-by: wenjh <wenjh@sugon.com>

Parents: 063ef88d, 70f53666
Changes: 61
Showing 20 changed files with 1391 additions and 767 deletions (+1391, -767)
tests/jax/test_distributed_softmax.py                                              +6    -4
tests/jax/test_fused_attn.py                                                       +3    -3
tests/jax/test_helper.py                                                           +81   -1
tests/pytorch/attention/run_attention_with_cp.py                                   +11   -4
tests/pytorch/attention/test_attention.py                                          +70   -20
tests/pytorch/attention/test_attention_with_cp.py                                  +3    -3
tests/pytorch/test_numerics.py                                                     +1    -29
tests/pytorch/utils.py                                                             +3    -0
transformer_engine/common/CMakeLists.txt                                           +282  -162
transformer_engine/common/__init__.py                                              +96   -47
transformer_engine/common/fused_attn/fused_attn.cpp                                +40   -40
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu            +268  -142
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h             +24   -22
transformer_engine/common/fused_attn/fused_attn_fp8.cu                             +4    -2
transformer_engine/common/fused_attn/utils.h                                       +3    -2
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu     +14   -13
transformer_engine/common/include/transformer_engine/fused_attn.h                  +42   -37
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu     +39   -37
transformer_engine/common/util/nvfp4_transpose.cuh                                 +142  -148
transformer_engine/common/util/ptx.cuh                                             +259  -51
tests/jax/test_distributed_softmax.py  (View file @ 0a5016b1)

@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, autocast(mesh_resource=mesh_resource):
-            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
-            mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
+            x_named_sharding = NamedSharding(mesh, x_pspec)
+            mask_named_sharding = NamedSharding(mesh, mask_pspec)
+            x_ = jax.device_put(x, x_named_sharding)
+            mask_ = jax.device_put(mask, mask_named_sharding)
             with warnings.catch_warnings(record=True) as warns:
                 try:
@@ -116,8 +118,8 @@ class TestDistributedSoftmax:
                         grad_args=(0,),
                         metric_fwd_dtype=dtype,
                         metric_bwd_dtype=dtype,
-                        in_shardings=(x_pspec, mask_pspec),
-                        out_shardings=(None, (x_pspec,)),
+                        in_shardings=(x_named_sharding, mask_named_sharding),
+                        out_shardings=(None, x_named_sharding),
                     )
                 except AssertionError as err:
                     # Softmax should still produce the correct numerical result with
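For context, the change above switches from passing bare PartitionSpecs to passing concrete NamedSharding objects into the jitted call's in_shardings/out_shardings. A minimal standalone sketch of that pattern follows; the array shapes and the "dp" axis name are illustrative, not taken from the test.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Build a 1D device mesh over all visible devices.
    devices = np.asarray(jax.devices())
    mesh = Mesh(devices, ("dp",))

    # A NamedSharding couples the mesh with a PartitionSpec, so jit can consume it directly.
    x_sharding = NamedSharding(mesh, PartitionSpec("dp", None))

    x = jnp.ones((8, 16))
    x_sharded = jax.device_put(x, x_sharding)

    # Passing the NamedSharding (rather than a bare PartitionSpec) does not depend on an
    # ambient mesh context, which is the behavior the updated test relies on.
    f = jax.jit(lambda a: a * 2, in_shardings=(x_sharding,), out_shardings=x_sharding)
    y = f(x_sharded)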
tests/jax/test_fused_attn.py  (View file @ 0a5016b1)

@@ -378,14 +378,14 @@ class FusedAttnRunner:
             pytest.skip(
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )
         # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
         if (
-            get_device_compute_capability(0) == 100
+            get_device_compute_capability(0) >= 100
             and self.dropout_prob == 0.1
             and self.attn_bias_type is not AttnBiasType.NO_BIAS
         ):
             pytest.skip(
-                "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
+                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
             )
         # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
         # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
tests/jax/test_helper.py  (View file @ 0a5016b1)

@@ -3,11 +3,13 @@
 # See LICENSE for license information.
 import unittest
+from functools import partial

+import flax
 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import linen as nn
 from utils import assert_allclose

 from transformer_engine.common.recipe import (
@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import (
     ScalingMode,
     update_collections,
+    TensorSource,
+    QuantizerFactory,
+    QuantizeLayout,
 )
 from transformer_engine.jax.quantize.helper import _format2dtypes
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
+from transformer_engine.jax.flax.module import TransformerEngineBase

 is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
 is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
 is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

+
+def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
+    """Check that the quantizers in the quantizer set are as expected and reconstructed correctly
+    from flattened pytree representations across VJP boundaries."""
+
+    # Define a function with a custom VJP (vector-Jacobian product)
+    @partial(jax.custom_vjp, nondiff_argnums=(1,))
+    def quantizer_check(inner_quantizer_set, assertion_func, x):
+        return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)
+
+    def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
+        assertion_func(inner_quantizer_set.x, TensorSource.X)
+        assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
+        assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
+        return x
+
+    def quantizer_check_bwd(ctx, g):
+        return (g,)
+
+    quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
+    return quantizer_check(outer_quantizer_set, assertion_func, x)
+
+
+class TestModule(TransformerEngineBase):
+    """A simple module to test quantizer creation and reconstruction across VJP boundaries."""
+
+    # Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
+    assertion_func: callable
+
+    @nn.compact
+    def __call__(self, x):
+        quantizer_set = self.generate_quantizer_set()
+        return quantizer_check_vjp(quantizer_set, self.assertion_func, x)
+
+
 class TestHelper(unittest.TestCase):

     @unittest.skipIf(not is_fp8_supported, reason=reason)
@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase):
         for tensor_source in TensorSource:
             target_scaling_mode = (
                 ScalingMode.NVFP4_2D_SCALING
-                if tensor_source == TensorSource.KERNEL
+                if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
                 else ScalingMode.NVFP4_1D_SCALING
             )
             self.assertEqual(get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode)
         self.assertEqual(get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding)
         self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht)
         self.assertEqual(get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization)

+    def _compare_nvfp4_scaling_quantizers(self, test):
+        """Check that the quantizers created have the expected stochastic rounding state and the
+        state is preserved across VJP boundaries."""
+
+        def assertion_func(quantizer, tensor_source):
+            if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
+                self.assertIsNone(quantizer.stochastic_rounding_rng_state)
+            else:
+                self.assertIsNotNone(quantizer.stochastic_rounding_rng_state)
+            expected_rht = (
+                quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
+                and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
+                and not test.disable_rht
+            )
+            self.assertEqual(quantizer.use_rht, expected_rht)
+
+        x = jnp.ones((), dtype=jnp.float32)
+        test_module = TestModule(assertion_func=assertion_func)
+        param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
+        rngs = {"params": param_key, "sr_rng": sr_key}
+        variables = test_module.init(rngs, x)
+        jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_autocast_delayed_scaling(self):
@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase):
         with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
             self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_nvfp4_scaling(bs)
+            self._compare_nvfp4_scaling_quantizers(bs)
+
+        bs = NVFP4BlockScaling(
+            disable_stochastic_rounding=True,
+            disable_rht=True,
+            disable_2d_quantization=True,
+        )
+        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
+            self._compare_nvfp4_scaling(bs)
+            self._compare_nvfp4_scaling_quantizers(bs)

         self._check_default_state()
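The new quantizer_check_vjp helper leans on jax.custom_vjp with nondiff_argnums to thread a non-differentiable callback through a VJP boundary, which is what lets the test inspect the quantizer pytrees on both the forward and backward passes. A standalone sketch of that mechanism, independent of TE's quantizer types (the function and callback names here are made up for illustration):

    from functools import partial
    import jax
    import jax.numpy as jnp

    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def check_and_pass(x, callback):
        callback(x)
        return x

    def check_and_pass_fwd(x, callback):
        # The forward rule sees the arguments after pytree flattening/unflattening.
        callback(x)
        return x, None  # (primal output, residuals)

    def check_and_pass_bwd(callback, residuals, g):
        # Non-differentiable args come first; return one cotangent per differentiable input.
        return (g,)

    check_and_pass.defvjp(check_and_pass_fwd, check_and_pass_bwd)

    probe = lambda v: None  # stand-in for the assertion callback used in the test
    val, grad = jax.value_and_grad(lambda x: check_and_pass(x, probe).sum())(jnp.ones(3))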
tests/pytorch/attention/run_attention_with_cp.py  (View file @ 0a5016b1)

@@ -248,6 +248,7 @@ def run_dpa_with_cp(
         attn_mask_type=config.attn_mask_type,
         window_size=config.window_size,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     ).cuda()
     if config.softmax_type != "vanilla":
         core_attn.softmax_offset.requires_grad = True
@@ -308,6 +309,7 @@ def run_dpa_with_cp(
         fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
     else:
         fp8_context = nullcontext()

+    max_logit = None
     with fp8_context:
         # q, k, v, out in FP8; dout in F16
         out = core_attn(
@@ -322,6 +324,8 @@ def run_dpa_with_cp(
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
             fp8_output=fp8_mha,
         )
+        if config.return_max_logit:
+            out, max_logit = out
         if fp8_bwd and fp8_mha:
             dout_fp8 = dout_quantizer(dout)
             out.backward(dout_fp8)
@@ -400,6 +404,7 @@ def run_dpa_with_cp(
         fp8_context = nullcontext()

     # run attention
+    max_logit_ = None
     with fp8_context:
         # q, k, v, out in FP8; dout in F16
         out_ = core_attn(
@@ -414,6 +419,8 @@ def run_dpa_with_cp(
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
             fp8_output=fp8_mha,
         )
+        if config.return_max_logit:
+            out_, max_logit_ = out_
         if fp8_bwd and fp8_mha:
             dout_fp8_ = dout_quantizer(dout_)
             out_.backward(dout_fp8_)
@@ -495,15 +502,15 @@ def run_dpa_with_cp(
     )
     atol, rtol, rmse_tol = get_tols(config, dtype)
-    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
-    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
-    names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
+    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
+    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
+    names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
     names_cp = [x + "_cp" for x in names]
     names_no_cp = [x + "_no_cp" for x in names]
     is_fp8 = dtype == "fp8"
     for i, t in enumerate(tensors_no_cp):
         if t is not None:
-            if "softmax_offset" not in names[i]:
+            if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
                 if qkv_format == "bshd":
                     compare_and_assert(
                         t[:, 0],
tests/pytorch/attention/test_attention.py  (View file @ 0a5016b1)

@@ -60,8 +60,16 @@ from utils import (
     get_available_attention_backends,
 )

-# Check if hardware supports FP8
+# Check if hardware supports FP8 attention.
 fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
+fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
+device_compute_capability = get_device_compute_capability()
+if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
+    fp8_attn_available = False
+    reason_for_no_fp8_attn = (
+        "FP8 attention is not supported for compute capability ="
+        f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
+    )

 # Reset RNG seed and states
 seed = 1234
@@ -130,6 +138,11 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
+    qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
+    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
+        config.attn_mask_type = (
+            "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding"
+        )

     # Get backends
     is_training = True
@@ -171,7 +184,7 @@ def test_dot_product_attention(
     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
-        unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
+        unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention(
             dtype,
             config,
             "UnfusedDotProductAttention",
@@ -185,7 +198,7 @@ def test_dot_product_attention(
     # FusedAttention backend
     if fused_attn_supported:
         if len(fused_attn_backends) == 1:
-            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+            fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -197,7 +210,7 @@ def test_dot_product_attention(
             )
         if len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-            fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
+            fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -208,7 +221,7 @@ def test_dot_product_attention(
                 is_training,
             )
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
-            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
+            fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -221,7 +234,7 @@ def test_dot_product_attention(
     # FlashAttention backend
     if flash_attn_supported:
-        flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
+        flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention(
            dtype,
            config,
            "FlashAttention",
@@ -242,6 +255,8 @@ def test_dot_product_attention(
     if unfused_attn_supported and fused_attn_supported:
         logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
         torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
+        if config.return_max_logit:
+            torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols)
         for i, _ in enumerate(unfused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
     if fused_attn_supported and flash_attn_supported:
@@ -265,6 +280,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
     test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)

+model_configs_max_logit = {
+    # test: ModelConfig(b, sq, hq, dqk)
+    "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096),
+    "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
+    "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"),
+    "max_logit_4": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias"),
+    "max_logit_5": ModelConfig(
+        8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0)
+    ),
+    "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
+}
+
+
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model_configs", [model_configs_max_logit])
+@pytest.mark.parametrize("model", model_configs_max_logit.keys())
+@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"])
+def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
+    """Test DotProductAttention module with checkpointing"""
+    config = model_configs[model]
+    config.return_max_logit = True
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
+
+
 model_configs_softmax = {
     # test: ModelConfig(b, sq, hq, dqk)
     "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -962,6 +1004,8 @@ def _run_dot_product_attention(
             layout = layout.replace("d", "dqk")
         tensor_shape = [dim_to_num[j] for j in layout.split("_")]
         tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
+        # tensor: with padding tokens
+        # tensor_orig: without padding tokens
         tensor_orig = tensor
         if qkv_format == "thd" and pad_between_seqs:
             tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1071,6 +1115,7 @@ def _run_dot_product_attention(
         layer_number=1,
         attention_type=config.attn_type,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     ).to(dtype=dtype, device="cuda")
     if not is_training:
         block = block.eval()
@@ -1108,16 +1153,21 @@ def _run_dot_product_attention(
         alibi_slopes=alibi_slopes,
         fast_zero_fill=True,
     )
+    max_logit = None
+    if config.return_max_logit:
+        out, max_logit = out
     if is_training:
         out.backward(d_out)
     d_softmax_offset = None
     if is_training and config.softmax_type != "vanilla":
         d_softmax_offset = block.softmax_offset.grad
     if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
         if is_training:
-            return out, (q.grad, k.grad, v.grad, d_softmax_offset)
+            return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
         else:
-            return out, (None, None, None, d_softmax_offset)
+            return out, max_logit, (None, None, None, d_softmax_offset)
     if backend == "FusedAttention":
         if qkv_format == "thd" and pad_between_seqs:
             out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1146,14 +1196,18 @@ def _run_dot_product_attention(
                     [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
                 )
             if is_training:
-                return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
+                return (
+                    out_orig,
+                    max_logit,
+                    (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset),
+                )
             else:
-                return out_orig, (None, None, None, d_softmax_offset)
+                return out_orig, max_logit, (None, None, None, d_softmax_offset)
         else:
            if is_training:
-                return out, (q.grad, k.grad, v.grad, d_softmax_offset)
+                return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset)
            else:
-                return out, (None, None, None, d_softmax_offset)
+                return out, max_logit, (None, None, None, d_softmax_offset)


 model_configs_te_layer = {
@@ -1527,8 +1581,7 @@ model_configs_fp8_extra_state = {
 }


-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
 @pytest.mark.parametrize("model", ["large"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -1690,8 +1743,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
 @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
 @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@@ -1927,8 +1979,7 @@ def _run_mha_fp8_vs_f16(
 @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
 @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
 @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@@ -2256,8 +2307,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
     ),
     reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
 )
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
+@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8)
 @pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
 def test_custom_mha_fp8_vs_f16(dtype, model):
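The pattern the updated tests follow: when a config sets return_max_logit, the attention module returns an (output, max_logit) pair that the caller unpacks before running the backward pass. A hedged sketch of that call site follows; shapes, dtypes, and constructor arguments other than return_max_logit are illustrative, and the flag itself is the one this commit threads through the test configs rather than a long-documented API.

    import torch
    from transformer_engine.pytorch import DotProductAttention

    # Hypothetical shapes in the default "sbhd" layout: (seq, batch, heads, head_dim).
    q = torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    block = DotProductAttention(
        num_attention_heads=16,
        kv_channels=64,
        attention_dropout=0.0,
        return_max_logit=True,  # flag exercised by the new test_dpa_max_logit test
    ).cuda()

    out = block(q, k, v)
    out, max_logit = out  # mirrors `if config.return_max_logit: out, max_logit = out`
    out.backward(torch.ones_like(out))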
tests/pytorch/attention/test_attention_with_cp.py  (View file @ 0a5016b1)

@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
 model_configs_fused_attn = {
     # test: ModelConfig(b, sq, hq, dqk)
-    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"),  # MHA
-    "cp_1_1": ModelConfig(2, 4096, 12, 128),  # MHA
+    "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True),  # MHA
+    "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True),  # MHA
     "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"),  # MHA
@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
 qkv_formats = ["bshd", "sbhd", "thd"]
 cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
 if test_essential:
-    configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+    configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
     model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
     dtypes = ["bf16", "fp8"]
     qkv_formats = ["sbhd", "thd"]
tests/pytorch/test_numerics.py  (View file @ 0a5016b1)

@@ -45,11 +45,10 @@ from transformer_engine.pytorch import (
 from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
-from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
-from utils import ModelConfig, reset_rng_states, get_available_attention_backends
+from utils import ModelConfig, reset_rng_states

 # Only run FP8 tests on supported devices.
@@ -135,23 +134,6 @@ if torch.cuda.get_device_capability() == (9, 0):
     use_cutlass_grouped_gemm.append(True)


-def is_fused_attn_available(
-    config: ModelConfig,
-    dtype: torch.dtype,
-    qkv_layout="bshd_bshd_bshd",
-    is_training=True,
-    deterministic=False,
-):
-    _, _, fused_attn_backends = get_available_attention_backends(
-        config,
-        qkv_dtype=dtype,
-        qkv_layout=qkv_layout,
-        is_training=is_training,
-        deterministic=deterministic,
-    )
-    return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
-
-
 def get_causal_attn_mask(sq: int) -> torch.Tensor:
     return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
@@ -872,8 +854,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
 @pytest.mark.parametrize("model", ["126m"])
 def test_gpt_checkpointing(dtype, bs, model):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, deterministic=True):
-        pytest.skip("No attention backend available.")
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
     outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
@@ -920,10 +900,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]
-    if not is_fused_attn_available(
-        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
-    ):
-        pytest.skip("No attention backend available.")
     te_gpt = TransformerLayer(
         hidden_size=config.hidden_size,
@@ -1035,10 +1011,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
 @pytest.mark.parametrize("mask_type", mask_types)
 def test_mha_accuracy(dtype, bs, model, mask_type):
     config = model_configs[model]
-    if not is_fused_attn_available(
-        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
-    ):
-        pytest.skip("No attention backend available.")
     te_mha = MultiheadAttention(
         config.hidden_size,
tests/pytorch/utils.py  (View file @ 0a5016b1)

@@ -205,6 +205,7 @@ class ModelConfig:
         window_size: Tuple[int, int] = (-1, -1),
         context_parallel: bool = False,
         cp_comm_type: str = "p2p",
+        return_max_logit=False,
         total_requests: int = None,
         max_ctx_len: int = None,
         num_layers: int = 1,
@@ -233,6 +234,7 @@ class ModelConfig:
         self.window_size = check_set_window_size(self.attn_mask_type, window_size)
         self.context_parallel = context_parallel
         self.cp_comm_type = cp_comm_type
+        self.return_max_logit = return_max_logit
         self.total_requests = total_requests
         self.max_ctx_len = max_ctx_len
         self.num_layers = num_layers
@@ -318,6 +320,7 @@ def get_available_attention_backends(
         is_training=is_training,
         inference_params=inference_params,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     )
     (
         use_flash_attention,
transformer_engine/common/CMakeLists.txt  (View file @ 0a5016b1)

@@ -29,15 +29,6 @@ endif()
 # Language options
 if(USE_CUDA)
-  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
-    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
-      set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
-    elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
-      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
-    else()
-      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
-    endif()
-  endif()
   set(CMAKE_CXX_STANDARD 17)
   set(CMAKE_CUDA_STANDARD 17)
   set(CMAKE_CUDA_STANDARD_REQUIRED ON)
@@ -54,8 +45,62 @@ if(USE_CUDA)
   # CUDA Toolkit
   find_package(CUDAToolkit REQUIRED)
-  if(CUDAToolkit_VERSION VERSION_LESS 12.0)
-    message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}")
+  if(CUDAToolkit_VERSION VERSION_LESS 12.1)
+    message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}")
   endif()

+  # Process GPU architectures
+  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0)
+      set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
+    elseif(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
+      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120)
+    else()
+      set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
+    endif()
+  endif()
+
+  # Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
+  set(NVTE_GENERIC_ARCHS)
+  set(NVTE_SPECIFIC_ARCHS)
+
+  # Check for architecture 100
+  list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index)
+  if(NOT arch_100_index EQUAL -1)
+    list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100")
+    list(APPEND NVTE_GENERIC_ARCHS "100")
+    list(APPEND NVTE_SPECIFIC_ARCHS "100a")
+    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
+      list(APPEND NVTE_SPECIFIC_ARCHS "103a")
+    endif()
+  endif()
+
+  # Check for architecture 101 (if we see this we are in toolkit <= 12.9)
+  list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index)
+  if(NOT arch_101_index EQUAL -1)
+    list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101")
+    list(APPEND NVTE_GENERIC_ARCHS "101")
+    list(APPEND NVTE_SPECIFIC_ARCHS "101a")
+  endif()
+
+  # Check for architecture 110 (if we see this we are in toolkit >= 13.0)
+  list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index)
+  if(NOT arch_110_index EQUAL -1)
+    list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110")
+    list(APPEND NVTE_GENERIC_ARCHS "110")
+    list(APPEND NVTE_SPECIFIC_ARCHS "110f")
+  endif()
+
+  # Check for architecture 120
+  list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index)
+  if(NOT arch_120_index EQUAL -1)
+    list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120")
+    list(APPEND NVTE_GENERIC_ARCHS "120")
+    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
+      list(APPEND NVTE_SPECIFIC_ARCHS "120f")
+    else()
+      list(APPEND NVTE_SPECIFIC_ARCHS "120a")
+    endif()
+  endif()

   # cuDNN frontend API
@@ -135,11 +180,29 @@ endif()
 # Configure Transformer Engine library
 include_directories(${PROJECT_SOURCE_DIR}/..)

-set(transformer_engine_SOURCES)
+set(transformer_engine_cpp_sources)
+set(transformer_engine_cuda_sources)
+set(transformer_engine_cuda_arch_specific_sources)
 if(USE_CUDA)
-  list(APPEND transformer_engine_SOURCES
+  list(APPEND transformer_engine_cpp_sources
+       cudnn_utils.cpp
+       transformer_engine.cpp
+       fused_attn/fused_attn.cpp
+       gemm/config.cpp
+       normalization/common.cpp
+       normalization/layernorm/ln_api.cpp
+       normalization/rmsnorm/rmsnorm_api.cpp
+       util/cuda_driver.cpp
+       util/cuda_nvml.cpp
+       util/cuda_runtime.cpp
+       util/multi_stream.cpp
+       util/rtc.cpp
+       comm_gemm_overlap/userbuffers/ipcsocket.cc
+       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
+       comm_gemm_overlap/comm_gemm_overlap.cpp)
+  list(APPEND transformer_engine_cuda_sources
        common.cu
        multi_tensor/adam.cu
        multi_tensor/compute_scale.cu
@@ -151,40 +214,23 @@ if(USE_CUDA)
        transpose/cast_transpose_fusion.cu
        transpose/transpose_fusion.cu
        transpose/multi_cast_transpose.cu
-       transpose/quantize_transpose_square_blockwise.cu
        transpose/quantize_transpose_vector_blockwise.cu
        transpose/swap_first_dims.cu
-       transpose/quantize_transpose_vector_blockwise_fp4.cu
-       activation/gelu.cu
        dropout/dropout.cu
        fused_attn/flash_attn.cu
        fused_attn/context_parallel.cu
        fused_attn/kv_cache.cu
        fused_attn/fused_attn_f16_max512_seqlen.cu
        fused_attn/fused_attn_f16_arbitrary_seqlen.cu
-       activation/relu.cu
-       activation/swiglu.cu
        fused_attn/fused_attn_fp8.cu
-       fused_attn/fused_attn.cpp
        fused_attn/utils.cu
-       gemm/config.cpp
        gemm/cublaslt_gemm.cu
-       gemm/cutlass_grouped_gemm.cu
-       normalization/common.cpp
-       normalization/layernorm/ln_api.cpp
        normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
        normalization/layernorm/ln_fwd_cuda_kernel.cu
-       normalization/rmsnorm/rmsnorm_api.cpp
        normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
        normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
        permutation/permutation.cu
-       util/cast.cu
        util/padding.cu
-       util/cuda_driver.cpp
-       util/cuda_nvml.cpp
-       util/cuda_runtime.cpp
-       util/multi_stream.cpp
-       util/rtc.cpp
        swizzle/swizzle.cu
        swizzle/swizzle_block_scaling.cu
        fused_softmax/scaled_masked_softmax.cu
@@ -198,25 +244,84 @@ if(USE_CUDA)
        recipe/delayed_scaling.cu
        recipe/fp8_block_scaling.cu
        recipe/nvfp4.cu
-       comm_gemm_overlap/userbuffers/ipcsocket.cc
-       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
-       comm_gemm_overlap/userbuffers/userbuffers.cu
-       comm_gemm_overlap/comm_gemm_overlap.cpp
-       )
+       comm_gemm_overlap/userbuffers/userbuffers.cu)
+  list(APPEND transformer_engine_cuda_arch_specific_sources
+       gemm/cutlass_grouped_gemm.cu
+       util/cast.cu
+       activation/gelu.cu
+       activation/relu.cu
+       activation/swiglu.cu
+       transpose/quantize_transpose_square_blockwise.cu
+       transpose/quantize_transpose_vector_blockwise_fp4.cu
+       hadamard_transform/hadamard_transform.cu
+       hadamard_transform/hadamard_transform_cast_fusion.cu)
+
+  # Compiling the files with the worst compilation time first to hopefully overlap
+  # better with the faster-compiling cpp files
+  list(APPEND transformer_engine_SOURCES
+       ${transformer_engine_cuda_arch_specific_sources}
+       ${transformer_engine_cuda_sources}
+       ${transformer_engine_cpp_sources})
+
+  # Set compile options for CUDA sources with generic architectures
+  foreach(cuda_source IN LISTS transformer_engine_cuda_sources)
+    set(arch_compile_options)
+    foreach(arch IN LISTS NVTE_GENERIC_ARCHS)
+      list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
+    endforeach()
+    if(arch_compile_options)
+      set_property(SOURCE ${cuda_source} APPEND PROPERTY COMPILE_OPTIONS ${arch_compile_options})
+    endif()
+  endforeach()
+
+  # Set compile options for CUDA sources with specific architectures
+  foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources)
+    set(arch_compile_options)
+    foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS)
+      list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}")
+    endforeach()
+    if(arch_compile_options)
+      set_property(SOURCE ${cuda_source} APPEND PROPERTY COMPILE_OPTIONS ${arch_compile_options})
+    endif()
+  endforeach()
+
   if(NVTE_WITH_CUBLASMP)
     list(APPEND transformer_engine_SOURCES comm_gemm/comm_gemm.cpp)
   endif()
   add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
 else()
-  list(APPEND transformer_engine_SOURCES
+  list(APPEND transformer_engine_cpp_sources
        cudnn_utils.cpp
        transformer_engine.cpp
        gemm/config.cpp
        normalization/common.cpp
        normalization/layernorm/ln_api.cpp
        normalization/rmsnorm/rmsnorm_api.cpp
        util/cuda_driver.cpp
        util/cuda_nvml.cpp
        util/cuda_runtime.cpp
        util/multi_stream.cpp
        util/rtc.cpp
        comm_gemm_overlap/userbuffers/ipcsocket.cc
        comm_gemm_overlap/userbuffers/userbuffers-host.cpp
        comm_gemm_overlap/comm_gemm_overlap.cpp)
+  list(APPEND transformer_engine_cuda_sources
        common.cu
        fused_attn/flash_attn.cu
        fused_attn/context_parallel.cu
        fused_attn/kv_cache.cu
        multi_tensor/adam.cu
        multi_tensor/compute_scale.cu
        multi_tensor/l2norm.cu
@@ -227,31 +332,23 @@ else()
        transpose/cast_transpose_fusion.cu
        transpose/transpose_fusion.cu
        transpose/multi_cast_transpose.cu
        transpose/quantize_transpose_square_blockwise.cu
        transpose/quantize_transpose_vector_blockwise.cu
        transpose/swap_first_dims.cu
        activation/gelu.cu
        dropout/dropout.cu
        activation/relu.cu
        activation/swiglu.cu
        gemm/config.cpp
        fused_attn/flash_attn.cu
        fused_attn/context_parallel.cu
        fused_attn/kv_cache.cu
        fused_attn/fused_attn_f16_max512_seqlen.cu
        fused_attn/fused_attn_f16_arbitrary_seqlen.cu
        fused_attn/fused_attn_fp8.cu
        fused_attn/utils.cu
        gemm/cublaslt_gemm.cu
        gemm/hipblas_gemm.cu
        normalization/common.cpp
        normalization/layernorm/ln_api.cpp
        normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
        normalization/layernorm/ln_fwd_cuda_kernel.cu
        normalization/rmsnorm/rmsnorm_api.cpp
        normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
        normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
        permutation/permutation.cu
        util/cast.cu
        util/padding.cu
        util/cuda_driver.cpp
        util/cuda_nvml.cpp
        util/cuda_runtime.cpp
        util/multi_stream.cpp
        util/rtc.cpp
        swizzle/swizzle.cu
        swizzle/swizzle_block_scaling.cu
        fused_softmax/scaled_masked_softmax.cu
@@ -264,10 +361,25 @@ else()
        recipe/current_scaling.cu
        recipe/delayed_scaling.cu
        recipe/fp8_block_scaling.cu
-       comm_gemm_overlap/userbuffers/ipcsocket.cc
-       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
-       comm_gemm_overlap/userbuffers/userbuffers.cu
-       comm_gemm_overlap/comm_gemm_overlap.cpp
-       )
+       recipe/nvfp4.cu
+       comm_gemm_overlap/userbuffers/userbuffers.cu)
+  list(APPEND transformer_engine_cuda_arch_specific_sources
+       gemm/cutlass_grouped_gemm.cu
+       util/cast.cu
+       activation/gelu.cu
+       activation/relu.cu
+       activation/swiglu.cu
+       transpose/quantize_transpose_square_blockwise.cu
+       transpose/quantize_transpose_vector_blockwise_fp4.cu
+       hadamard_transform/hadamard_transform.cu
+       hadamard_transform/hadamard_transform_cast_fusion.cu)
+  # Compiling the files with the worst compilation time first to hopefully overlap
+  # better with the faster-compiling cpp files
+  list(APPEND transformer_engine_SOURCES
+       ${transformer_engine_cuda_arch_specific_sources}
+       ${transformer_engine_cuda_sources}
+       ${transformer_engine_cpp_sources})
   if(NVTE_WITH_CUBLASMP)
     list(APPEND transformer_engine_SOURCES comm_gemm/comm_gemm.cpp)
@@ -316,8 +428,10 @@ if (USE_CUDA)
        CUDA::cublas
        CUDA::cudart
        CUDNN::cudnn_all)
   target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
   target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR})
+  target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
   target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
@@ -436,7 +550,8 @@ target_include_directories(transformer_engine PRIVATE
                            "${CMAKE_CURRENT_BINARY_DIR}/string_headers")

 # Compiler options
-set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
+set(nvte_sources_with_fast_math)
+list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
      fused_softmax/scaled_upper_triang_masked_softmax.cu
      fused_softmax/scaled_aligned_causal_masked_softmax.cu
      multi_tensor/adam.cu
@@ -446,20 +561,25 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
      multi_tensor/sgd.cu
      fused_attn/flash_attn.cu
      fused_attn/context_parallel.cu
-     fused_attn/kv_cache.cu
-     PROPERTIES
-     COMPILE_OPTIONS
-     "--use_fast_math")
+     fused_attn/kv_cache.cu)

 option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
 if(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
-  set_source_files_properties(activation/gelu.cu
+  list(APPEND nvte_sources_with_fast_math
+       activation/gelu.cu
        activation/relu.cu
        activation/swiglu.cu
-       util/cast.cu
-       PROPERTIES
-       COMPILE_OPTIONS
-       "--use_fast_math")
+       util/cast.cu)
 endif()

 if(USE_CUDA)
+  foreach(cuda_source IN LISTS nvte_sources_with_fast_math)
+    set_property(SOURCE ${cuda_source} APPEND PROPERTY COMPILE_OPTIONS "--use_fast_math")
+  endforeach()
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
 else()
transformer_engine/common/__init__.py  (View file @ 0a5016b1)

@@ -8,22 +8,18 @@ import ctypes
 import functools
 import glob
 import importlib
-from importlib.metadata import version, metadata, PackageNotFoundError
+import logging
+from importlib.metadata import version, distribution, PackageNotFoundError
 import os
 from pathlib import Path
 import platform
 import subprocess
 import sys
 import sysconfig
-from typing import Optional
+from typing import Optional, Tuple

 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 _logger = logging.getLogger(__name__)


 @functools.lru_cache(maxsize=None)
-def _is_pip_package_installed(package) -> bool:
+def _is_package_installed(package) -> bool:
     """Check if the given package is installed via pip."""
     # This is needed because we only want to return true
@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool:
     # if it's importable in the current directory due to
     # the presence of the shared library module.
     try:
-        metadata(package)
+        distribution(package)
     except PackageNotFoundError:
         return False
     return True


+@functools.lru_cache(maxsize=None)
+def _is_package_installed_from_wheel(package) -> bool:
+    """Check if the given package is installed via PyPI."""
+    if not _is_package_installed(package):
+        return False
+    te_dist = distribution(package)
+    te_wheel_file = ""
+    for file_path in te_dist.files:
+        if file_path.name == "WHEEL":
+            te_wheel_file = te_dist.locate_file("") / file_path
+    if not te_wheel_file:
+        return False
+    with te_wheel_file.open("r") as f:
+        for line in f:
+            if line.startswith("Root-Is-Purelib:"):
+                return line.strip().split(":")[1].strip().lower() == "true"
+    return False
+
+
 @functools.lru_cache(maxsize=None)
 def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
     """
@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path:
     )


+def get_te_core_package_info() -> Tuple[bool, str, str]:
+    """
+    Check if Tranformer Engine core package is installed.
+    Returns the module name and version if found.
+    """
+    te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13")
+    for package in te_core_packages:
+        if _is_package_installed(package):
+            return True, package, version(package)
+    return False, "", ""
+
+
 @functools.lru_cache(maxsize=None)
 def load_framework_extension(framework: str) -> None:
     """
@@ -130,37 +161,28 @@ def load_framework_extension(framework: str) -> None:
     if framework == "torch":
         extra_dep_name = "pytorch"

+    # Find the TE packages. The core and framework packages can only be installed via PyPI.
+    # For the `transformer-engine` package, we need to check explicity.
+    te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
+    te_framework_installed = _is_package_installed(module_name)
+    te_installed = _is_package_installed("transformer_engine")
+    te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")
+
+    assert te_installed, "Could not find `transformer_engine`."
+
     # If the framework extension pip package is installed, it means that TE is installed via
     # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
-    # extension are all installed via PyPI and have matching version.
-    if _is_pip_package_installed(module_name):
-        assert _is_pip_package_installed("transformer_engine"), "Could not find `transformer-engine`."
-        assert _is_pip_package_installed("transformer_engine_cu12"), "Could not find `transformer-engine-cu12`."
-        assert (
-            version(module_name) == version("transformer-engine") == version("transformer-engine-cu12")
-        ), (
-            "TransformerEngine package version mismatch. Found"
-            f" {module_name} v{version(module_name)}, transformer-engine"
-            f" v{version('transformer-engine')}, and transformer-engine-cu12"
-            f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
-            f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
-        )
-    # If the core package is installed via PyPI, log if
-    # the framework extension is not found from PyPI.
-    # Note: Should we error? This is a rare use case.
-    if _is_pip_package_installed("transformer-engine-cu12"):
-        if not _is_pip_package_installed(module_name):
-            _logger.info(
-                "Could not find package %s. Install transformer-engine using "
-                f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'",
-                module_name,
-            )
+    # extension are all installed via PyPI and have matching versions.
+    if te_framework_installed:
+        assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+        assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."
+        assert version(module_name) == version("transformer-engine") == te_core_version, (
+            "Transformer Engine package version mismatch. Found"
+            f" {module_name} v{version(module_name)}, transformer-engine"
+            f" v{version('transformer-engine')}, and {te_core_package_name}"
+            f" v{te_core_version}. Install transformer-engine using "
+            f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'"
+        )

     # After all checks are completed, load the shared object file.
@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None:
     spec.loader.exec_module(solib)


+def sanity_checks_for_pypi_installation() -> None:
+    """Ensure that package is installed correctly if using PyPI."""
+    te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info()
+    te_installed = _is_package_installed("transformer_engine")
+    te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine")
+
+    assert te_installed, "Could not find `transformer-engine`."
+
+    # If the core package is installed via PyPI.
+    if te_core_installed:
+        assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+        assert version("transformer-engine") == te_core_version, (
+            "Transformer Engine package version mismatch. Found "
+            f"transformer-engine v{version('transformer-engine')} "
+            f"and {te_core_package_name} v{te_core_version}."
+        )
+    # Only the metapackage is found, invalid usecase.
+    elif te_installed_via_pypi:
+        raise RuntimeError(
+            "Found empty `transformer-engine` meta package installed. "
+            "Install `transformer-engine` with framework extensions via"
+            "'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'"
+            " or 'pip3 install transformer-engine[core]` for the TE core lib only. The `core_cu12`"
+            " or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib."
+        )
+
+
 @functools.lru_cache(maxsize=None)
 def _get_sys_extension() -> str:
     """File extension for shared objects."""
@@ -339,16 +390,14 @@ def _load_core_library():
 if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
     try:
+        sanity_checks_for_pypi_installation()
         _CUDNN_LIB_CTYPES = _load_cudnn()
         _NVRTC_LIB_CTYPES = _load_nvrtc()
         _CURAND_LIB_CTYPES = _load_curand()
         _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
         _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
+        _TE_LIB_CTYPES = _load_core_library()
         # Needed to find the correct headers for NVRTC kernels.
         if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
             os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
     except OSError:
         pass
-    _TE_LIB_CTYPES = _load_core_library()
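The new _is_package_installed_from_wheel helper distinguishes a wheel-based PyPI install from a from-source build by reading the installed distribution's WHEEL metadata file and its Root-Is-Purelib flag. A minimal standalone sketch of that check using only importlib.metadata follows; the function name and the example package queried at the end are illustrative, not part of Transformer Engine.

    from importlib.metadata import distribution, PackageNotFoundError

    def installed_from_pure_wheel(package: str) -> bool:
        """Return True if `package` ships a WHEEL file declaring Root-Is-Purelib: true."""
        try:
            dist = distribution(package)
        except PackageNotFoundError:
            return False
        for file_path in dist.files or []:
            if file_path.name == "WHEEL":
                wheel_file = dist.locate_file(file_path)
                with open(wheel_file, "r", encoding="utf-8") as f:
                    for line in f:
                        if line.startswith("Root-Is-Purelib:"):
                            return line.split(":", 1)[1].strip().lower() == "true"
        return False

    print(installed_from_pure_wheel("pip"))  # example query against an arbitrary package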
transformer_engine/common/fused_attn/fused_attn.cpp
View file @
0a5016b1
...
...
@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
int64_t
window_size_right
,
bool
return_max_logit
)
{
using
namespace
transformer_engine
;
NVTE_Fused_Attn_Backend
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_No_Backend
;
const
int
device_id
=
cuda
::
current_device
();
...
...
@@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_BSHD
||
qkv_format
==
NVTE_QKV_Format
::
NVTE_SBHD
)
&&
!
requires_64bit_ragged_offset
&&
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
)
&&
// 9.10.0: known bugs with SDPA FP8
(
cudnn_runtime_version
!=
91000
))
{
(
cudnn_runtime_version
!=
91000
)
&&
!
return_max_logit
)
{
if
(
cudnn_runtime_version
>=
8900
)
{
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_FP8
;
}
else
{
...
...
@@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(
qkv_layout
==
NVTE_QKV_Layout
::
NVTE_BSHD_BSHD_BSHD
))
&&
((
window_size_left
==
-
1
)
&&
(
window_size_right
==
-
1
||
window_size_right
==
0
))
&&
!
requires_64bit_ragged_offset
&&
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
))
{
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
)
&&
!
return_max_logit
)
{
flag_m512
=
true
;
}
if
(
...
...
@@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
size_t
max_seqlen
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
size_t
max_seqlen
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
...
...
@@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
is_training
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h
,
h
,
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
);
h
,
h
,
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
,
return_max_logit
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
...
...
@@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
}
else
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
)
{
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
b
,
h
,
max_seqlen
,
d
,
t
,
is_training
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
input_QKV
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens
,
input_cu_seqlens_padded
,
input_rng_state
,
wkspace
,
stream
,
handle
);
b
,
h
,
max_seqlen
,
d
,
t
,
is_training
,
return_max_logit
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
input_QKV
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens
,
input_cu_seqlens_padded
,
input_rng_state
,
wkspace
,
stream
,
handle
);
#else
NVTE_ERROR
(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
...
...
@@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
true
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h
,
h
,
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
);
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
,
false
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
...
...
@@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked(
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
...
...
@@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
);
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
,
return_max_logit
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
...
...
@@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked(
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked
(
b
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
t_q
,
t_kv
,
num_pages_k
,
num_pages_v
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
input_Q
,
input_KV
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_q_padded
,
input_cu_seqlens_
kv
_padded
,
input_
page_table_k
,
input_page_table_
v
,
input_rng_state
,
wkspace
,
stream
,
handle
);
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
return_max_logit
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
input_Q
,
input_KV
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_
q
_padded
,
input_
cu_seqlens_kv_padded
,
input_page_table_
k
,
input_page_table_v
,
input_rng_state
,
wkspace
,
stream
,
handle
);
#else
NVTE_ERROR
(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
...
...
@@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
true
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
);
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
,
false
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
...
...
@@ -832,17 +833,15 @@ void nvte_fused_attn_bwd_kvpacked(
}
}
// NVTE fused attention FWD with separate Q, K and V
void
nvte_fused_attn_fwd
(
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
void
nvte_fused_attn_fwd
(
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_flash_attn_fwd
);
using
namespace
transformer_engine
;
...
...
@@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
);
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
,
return_max_logit
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
...
...
@@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 #if (CUDNN_VERSION >= 8900)
     fused_attn_arbitrary_seqlen_fwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
-        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
-        dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
-        window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O,
-        Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
-        input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
-        wkspace, stream, handle);
+        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
+        return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
+        window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias,
+        input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
+        input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
+        input_page_table_v, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.\n");
...
@@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
-      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
...
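Note on the change running through the hunks above: every call to nvte_get_fused_attn_backend() now
takes one extra trailing boolean. The backward wrappers pass a literal false, while the forward path
forwards its new return_max_logit parameter. A minimal caller-side sketch follows; only the
nvte_get_fused_attn_backend signature is taken from this diff, and every other name and value is an
illustrative placeholder.

// Sketch only: where the new trailing flag goes in a backend query.
#include <transformer_engine/fused_attn.h>

NVTE_Fused_Attn_Backend pick_backend(NVTEDType dtype, NVTE_QKV_Layout layout, NVTE_Bias_Type bias,
                                     NVTE_Mask_Type mask, NVTE_Softmax_Type softmax,
                                     bool return_max_logit) {
  // Forward path: forward the caller's flag; the backward wrappers in this
  // commit pass a literal `false` in the same position.
  return nvte_get_fused_attn_backend(/*is_training=*/true, dtype, dtype, layout, bias, mask,
                                     softmax, /*dropout=*/0.0f, /*num_attn_heads=*/16,
                                     /*num_gqa_groups=*/16, /*max_seqlen_q=*/2048,
                                     /*max_seqlen_kv=*/2048, /*head_dim_qk=*/64, /*head_dim_v=*/64,
                                     /*window_size_left=*/-1, /*window_size_right=*/-1,
                                     return_max_logit);
}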
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
This diff is collapsed.

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
...
@@ -20,12 +20,13 @@ namespace transformer_engine {
 #if (CUDNN_VERSION >= 8900)
 void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
-    bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
-    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
-    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    bool is_training, bool return_max_logit, float attn_scale, float p_dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
+    const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle);
 
 void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
...
@@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
     size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
-    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
-    float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
-    const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
-    const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
-    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit,
+    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
+    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
+    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
+    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
 
 void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
...
@@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd(
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
     size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
     size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
-    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
-    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
-    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
+    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
+    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
 
 void fused_attn_arbitrary_seqlen_bwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
...
transformer_engine/common/fused_attn/fused_attn_fp8.cu
...
@@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1(
                              qkv_tensor_type,
                              o_tensor_type,
                              cudnn_frontend::DataType_t::NOT_SET,
-                             cudnn_frontend::DataType_t::NOT_SET};
+                             cudnn_frontend::DataType_t::NOT_SET,
+                             false};
   namespace fe = cudnn_frontend;
   using graph_and_tensors =
...
@@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1(
                              qkv_tensor_type,
                              o_tensor_type,
                              do_tensor_type,
-                             dqkv_tensor_type};
+                             dqkv_tensor_type,
+                             false};
   namespace fe = cudnn_frontend;
   using graph_and_tensors =
...
transformer_engine/common/fused_attn/utils.h
...
@@ -115,20 +115,21 @@ struct FADescriptor_v1 {
   cudnn_frontend::DataType_t o_tensor_type;
   cudnn_frontend::DataType_t do_tensor_type;
   cudnn_frontend::DataType_t dqkv_tensor_type;
+  bool generate_max_sum_exp;
 
   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                     page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                     attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
                     window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
-                    o_tensor_type, do_tensor_type, dqkv_tensor_type) <
+                    o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                     rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                     rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
                     rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
                     rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                    rhs.dqkv_tensor_type);
+                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
   }
 };
...
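Aside on the comparison above: std::tie builds tuples of references and compares them
lexicographically, so a descriptor field only participates in the ordering if it appears on both
sides. FADescriptor_v1 is used as an ordered lookup key (presumably for cached cuDNN execution
plans), so the new generate_max_sum_exp member has to be added to both tuples, otherwise descriptors
differing only in that flag would compare equal and alias each other. A standalone sketch of the
pattern, with hypothetical names:

// Sketch only (not TE code): why a new member must be added to both std::tie calls.
#include <cstdint>
#include <map>
#include <tuple>

struct DescriptorSketch {
  std::int64_t seqlen_q;
  std::int64_t seqlen_kv;
  bool generate_max_sum_exp;  // mirrors the field added in the diff above

  bool operator<(const DescriptorSketch &rhs) const {
    return std::tie(seqlen_q, seqlen_kv, generate_max_sum_exp) <
           std::tie(rhs.seqlen_q, rhs.seqlen_kv, rhs.generate_max_sum_exp);
  }
};

// Hypothetical cache keyed by the descriptor: entries that differ only in
// generate_max_sum_exp now occupy distinct slots instead of colliding.
std::map<DescriptorSketch, int> plan_cache;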
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
...
@@ -97,7 +97,8 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
 StochasticNumericConverterBase(cutlass::Array<float, 8> const &input,
                                cutlass::Array<uint32_t, 2> const &rbits) {
   using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
   result_type output;
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
     auto output_ptr = reinterpret_cast<uint16_t *>(&output);
     asm volatile( \
         "{\n" \
...
@@ -109,10 +110,10 @@ StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::A
         : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]),
           "f"(input[5]), "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1]));
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return output;
 }
...
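The hunks above (and the similar ones in the FP4 files below) all make the same change: the
architecture gate moves from preprocessor #if/#else/#endif blocks to a constexpr flag tested with
if constexpr, with NVTE_DEVICE_ERROR in the unsupported branch. A generic sketch of the pattern
with stand-in names (HAS_FEATURE and device_error are placeholders, not TE symbols):

// Sketch of the guard pattern in CUDA C++. HAS_FEATURE stands in for
// ARCH_HAS_STOCHASTIC_ROUNDING / ARCH_BLACKWELL_FAMILY, device_error for NVTE_DEVICE_ERROR.
#include <cuda_runtime.h>

constexpr bool HAS_FEATURE = false;  // in real code this would be derived from __CUDA_ARCH__

__device__ void device_error(const char *msg) {
  (void)msg;   // NVTE_DEVICE_ERROR would also report the message; here we just trap
  __trap();
}

__device__ int convert(int x) {
  if constexpr (HAS_FEATURE) {
    return x * 2;  // architecture-specific fast path (inline PTX in the real code)
  } else {
    // The unsupported branch is statically dead, so compilers drop it; unlike the
    // #if/#else version, both branches must still parse and type-check everywhere.
    device_error("feature not available on this architecture");
    return 0;
  }
}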
transformer_engine/common/include/transformer_engine/fused_attn.h
...
@@ -206,13 +206,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  *  \param[in]     head_dim_v        The head dimension of V.
  *  \param[in]     window_size_left  Sliding window size (the left half).
  *  \param[in]     window_size_right Sliding window size (the right half).
+ *  \param[in]     return_max_logit  Whether to produce Max and Sum_Exp, or Stats.
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right);
+    int64_t window_size_right, bool return_max_logit);
 
 /*! \brief Compute dot product attention with packed QKV input.
  *
...
@@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  *  \param[in]     max_seqlen        Max sequence length used for computing,
  *                                   it may be >= max(seqlen_i) for i=0,...batch_size-1.
  *  \param[in]     is_training       Whether this is in training mode or inference.
+ *  \param[in]     return_max_logit  Whether to produce Max and Sum_Exp, or Stats.
  *  \param[in]     attn_scale        Scaling factor for Q * K.T.
  *  \param[in]     dropout           Dropout probability.
  *  \param[in]     qkv_layout        QKV tensor's layout.
...
@@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  *  \param[in]     workspace         Workspace tensor.
  *  \param[in]     stream            CUDA stream used for this operation.
  */
 void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
                                    const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
                                    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
                                    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
-                                   size_t max_seqlen, bool is_training, float attn_scale,
-                                   float dropout, NVTE_QKV_Layout qkv_layout,
+                                   size_t max_seqlen, bool is_training, bool return_max_logit,
+                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                                    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
                                    NVTE_Softmax_Type softmax_type, int64_t window_size_left,
                                    int64_t window_size_right, NVTETensor workspace,
                                    cudaStream_t stream);
 
 /*! \brief Compute the backward of the dot product attention with packed QKV input.
  *
...
@@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  *  \param[in]     max_seqlen_kv     Max sequence length used for computing for KV.
  *                                   it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
  *  \param[in]     is_training       Whether this is in training mode or inference.
+ *  \param[in]     return_max_logit  Whether to produce Max and Sum_Exp, or Stats.
  *  \param[in]     attn_scale        Scaling factor for Q * K.T.
  *  \param[in]     dropout           Dropout probability.
  *  \param[in]     qkv_layout        QKV tensor's layout.
...
@@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked(
     const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
     const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
     const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
-    size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
+    size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
     NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
     NVTETensor workspace, cudaStream_t stream);
...
@@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked(
  *  \param[in]     max_seqlen_kv     Max sequence length used for computing for K and V.
  *                                   it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
  *  \param[in]     is_training       Whether this is in training mode or inference.
+ *  \param[in]     return_max_logit  Whether to produce Max and Sum_Exp, or Stats.
  *  \param[in]     attn_scale        Scaling factor for Q * K.T.
  *  \param[in]     dropout           Dropout probability.
  *  \param[in]     qkv_layout        QKV tensors' layout.
...
@@ -531,17 +538,15 @@ void nvte_fused_attn_bwd_kvpacked(
  *  \param[in]     workspace         Workspace tensor.
  *  \param[in]     stream            CUDA stream used for this operation.
  */
 void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                          const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
                          NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
                          const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
                          const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
                          const NVTETensor page_table_k, const NVTETensor page_table_v,
                          const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv,
-                         bool is_training, float attn_scale, float dropout,
+                         bool is_training, bool return_max_logit, float attn_scale, float dropout,
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                          NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                          int64_t window_size_left, int64_t window_size_right,
                          NVTETensor workspace, cudaStream_t stream);
 
 /*! \brief Compute the backward of the dot product attention with separate Q, K and V.
...
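Context for the return_max_logit flag documented above: the conventional output is a single per-row
statistic ("Stats", the log-sum-exp of the masked attention logits), while the new mode returns the
row maximum and the sum of exponentials separately. Assuming the usual flash-attention bookkeeping,
the two forms are related by standard softmax algebra (this identity is not a quote of TE
internals):

// Hedged sketch: relation between "Stats" and the separate "Max"/"Sum_Exp" pieces.
//   stats = max_logit + log(sum_exp), with sum_exp = sum_j exp(logit_j - max_logit)
#include <cmath>

double stats_from_max_and_sum_exp(double max_logit, double sum_exp) {
  return max_logit + std::log(sum_exp);
}

Keeping the two pieces separate is what a caller needs if it wants to rescale or merge partial
softmax results itself, for example across chunks of the key/value sequence, which is presumably why
the option was added.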
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
...
@@ -264,7 +264,8 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s
 __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
     const float2 in01, const float2 in23, const uint32_t rbits) {
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
     uint16_t out_4x;
     asm volatile(
         "{\n"
...
@@ -273,19 +274,20 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro
         : "=h"(out_4x)
         : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
     return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&out_4x);
-#else
-    NVTE_DEVICE_ERROR(
-        "FP4 cvt PTX instructions are architecture-specific. "
-        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  } else {
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt.rs PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
     uint16_t dummy = 0;
     return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&dummy);
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
 }
 
 __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
                                                                       const float2 in23,
                                                                       const uint32_t rbits) {
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY;
+  if constexpr (has_fp4) {
+    // NOTE: rbits unused for rn.
     uint32_t out_4x;  // Only need 16 bit. Using 32 bit container for packing.
     asm volatile(
...
@@ -299,13 +301,13 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const floa
         : "=r"(out_4x)
         : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
     return reinterpret_cast<__nv_fp4x4_e2m1 *>(&out_4x)[0];
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
     uint16_t dummy = 0;
     return *reinterpret_cast<__nv_fp4x4_e2m1 *>(&dummy);
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
 }
 
 template <bool kApplyStochasticRounding>
...
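The pairs of helpers above differ only in rounding mode: the *_with_stochastic_rounding variants
feed random bits (rbits) to the cvt.rs PTX path, while the *_with_rn variants round to nearest and
ignore rbits. As a reference for what stochastic rounding means numerically, here is an illustrative
scalar model (not the bit-exact cvt.rs semantics):

// Illustrative only: round up with probability equal to the fractional distance to the
// next representable value, so the quantization error is zero-mean on average.
#include <cmath>
#include <cstdint>

double round_stochastic(double x, double step, std::uint32_t rbits) {
  const double lo = std::floor(x / step) * step;  // nearest grid point at or below x
  const double frac = (x - lo) / step;            // fractional position in [0, 1)
  const double u = rbits / 4294967296.0;          // uniform sample in [0, 1) from the random bits
  return (u < frac) ? lo + step : lo;
}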
transformer_engine/common/util/nvfp4_transpose.cuh
...
@@ -15,10 +15,9 @@
 #include <cudaTypedefs.h>
 #include <cuda_runtime.h>
-#if CUDA_VERSION > 12080
+#if FP4_TYPE_SUPPORTED
 #include <cuda_fp4.h>
-#endif  // CUDA_VERSION > 12080
+#endif  // FP4_TYPE_SUPPORTED
 #include <cfloat>
 #include "../common.h"
...
@@ -30,7 +29,7 @@
 namespace transformer_engine {
-#if CUDA_VERSION > 12080
+#if FP4_TYPE_SUPPORTED
 namespace nvfp4_transpose {
 
 using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
...
@@ -152,12 +151,11 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int
   return rbits;
 }
 
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
     const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
   uint16_t out_4x = 0;
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
     asm volatile(
         "{\n"
         ".reg.b64 v01; \n\t"
...
@@ -185,20 +183,21 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_roun
         "}"
         : "=h"(out_4x)
         : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
 }
 
 __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
                                                                     const float2 scale,
                                                                     const uint32_t rbits) {
-  // NOTE: rbits unused for rn.
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
   uint32_t out_4x = 0;  // Only need 16 bit. Using 32 bit container for packing.
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  if constexpr (is_blackwell) {
+    // NOTE: rbits unused for rn.
     asm volatile(
         "{\n"
         ".reg.b64 v01; \n\t"
...
@@ -230,11 +229,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64
         "}"
         : "=r"(out_4x)
         : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
 }
...
@@ -252,7 +251,8 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x
 __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
     const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
   uint16_t out_4x = 0;
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
     asm volatile(
         "{\n"
         ".reg.b64 v01; \n\t"
...
@@ -275,11 +275,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_roun
         : "l"(reinterpret_cast<const uint64_t &>(in01)),
           "l"(reinterpret_cast<const uint64_t &>(in23)),
           "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
 }
...
@@ -287,9 +287,10 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
                                                                     const float2 in23,
                                                                     const float2 scale,
                                                                     const uint32_t rbits) {
-  // NOTE: rbits unused for rn.
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
   uint32_t out_4x = 0;  // Only need 16 bit. Using 32 bit container for packing.
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  if constexpr (is_blackwell) {
+    // NOTE: rbits unused for rn.
     asm volatile(
         "{\n"
         ".reg.b64 v01; \n\t"
...
@@ -316,11 +317,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
         : "l"(reinterpret_cast<const uint64_t &>(in01)),
           "l"(reinterpret_cast<const uint64_t &>(in23)),
           "l"(reinterpret_cast<const uint64_t &>(scale)));
-#else
+  } else {
     NVTE_DEVICE_ERROR(
         "FP4 cvt PTX instructions are architecture-specific. "
         "Try recompiling with sm_XXXa instead of sm_XXX.");
-#endif  // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
+  }
   return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
 }
...
@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
 }
 }
 
-#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-
 template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
           typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
 __global__ void __launch_bounds__(THREADS_NUM)
...
@@ -1380,18 +1379,13 @@
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }
 }  // namespace nvfp4_transpose
-#endif  // CUDA_VERSION > 12080
-
-// Compile-time flag to choose kernel variant
-#ifndef USE_2D_NVFP4_KERNEL
-#define USE_2D_NVFP4_KERNEL 0
-#endif
+#endif  // FP4_TYPE_SUPPORTED
 
 template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
           bool use_2d_quantization>
 void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
                               const QuantizationConfig *quant_config, cudaStream_t stream) {
-#if CUDA_VERSION > 12080
+#if FP4_TYPE_SUPPORTED
   bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
   // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
...
@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o
     }););
 #else
   NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
-#endif  // CUDA_VERSION > 12080
+#endif  // FP4_TYPE_SUPPORTED
 }
 
 }  // namespace transformer_engine
...
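The CUDA_VERSION > 12080 checks in this header are replaced by a single FP4_TYPE_SUPPORTED macro.
Its definition is not part of this diff; as an assumption, it is presumably provided once by a common
header roughly along the lines of the sketch below, so every FP4 translation unit keys off one switch
instead of repeating the raw version test at each use site.

// Assumed shape of the macro (hypothetical; the real definition lives elsewhere in the tree).
// It mirrors the condition that was previously written inline.
#ifndef FP4_TYPE_SUPPORTED
#define FP4_TYPE_SUPPORTED (CUDA_VERSION > 12080)
#endif

#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#endif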
transformer_engine/common/util/ptx.cuh
This diff is collapsed.