Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0a5016b1
Commit
0a5016b1
authored
Dec 03, 2025
by
wenjh
Browse files
Merge nv release_v2.9
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
063ef88d
70f53666
Changes
61
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1391 additions
and
767 deletions
+1391
-767
tests/jax/test_distributed_softmax.py
tests/jax/test_distributed_softmax.py
+6
-4
tests/jax/test_fused_attn.py
tests/jax/test_fused_attn.py
+3
-3
tests/jax/test_helper.py
tests/jax/test_helper.py
+81
-1
tests/pytorch/attention/run_attention_with_cp.py
tests/pytorch/attention/run_attention_with_cp.py
+11
-4
tests/pytorch/attention/test_attention.py
tests/pytorch/attention/test_attention.py
+70
-20
tests/pytorch/attention/test_attention_with_cp.py
tests/pytorch/attention/test_attention_with_cp.py
+3
-3
tests/pytorch/test_numerics.py
tests/pytorch/test_numerics.py
+1
-29
tests/pytorch/utils.py
tests/pytorch/utils.py
+3
-0
transformer_engine/common/CMakeLists.txt
transformer_engine/common/CMakeLists.txt
+282
-162
transformer_engine/common/__init__.py
transformer_engine/common/__init__.py
+96
-47
transformer_engine/common/fused_attn/fused_attn.cpp
transformer_engine/common/fused_attn/fused_attn.cpp
+40
-40
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
...gine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+268
-142
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
...ngine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
+24
-22
transformer_engine/common/fused_attn/fused_attn_fp8.cu
transformer_engine/common/fused_attn/fused_attn_fp8.cu
+4
-2
transformer_engine/common/fused_attn/utils.h
transformer_engine/common/fused_attn/utils.h
+3
-2
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
...mmon/hadamard_transform/hadamard_transform_cast_fusion.cu
+14
-13
transformer_engine/common/include/transformer_engine/fused_attn.h
...mer_engine/common/include/transformer_engine/fused_attn.h
+42
-37
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
...mmon/transpose/quantize_transpose_vector_blockwise_fp4.cu
+39
-37
transformer_engine/common/util/nvfp4_transpose.cuh
transformer_engine/common/util/nvfp4_transpose.cuh
+142
-148
transformer_engine/common/util/ptx.cuh
transformer_engine/common/util/ptx.cuh
+259
-51
No files found.
tests/jax/test_distributed_softmax.py
View file @
0a5016b1
...
@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
...
@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
with
mesh
,
autocast
(
mesh_resource
=
mesh_resource
):
with
mesh
,
autocast
(
mesh_resource
=
mesh_resource
):
x_
=
jax
.
device_put
(
x
,
NamedSharding
(
mesh
,
x_pspec
))
x_named_sharding
=
NamedSharding
(
mesh
,
x_pspec
)
mask_
=
jax
.
device_put
(
mask
,
NamedSharding
(
mesh
,
mask_pspec
))
mask_named_sharding
=
NamedSharding
(
mesh
,
mask_pspec
)
x_
=
jax
.
device_put
(
x
,
x_named_sharding
)
mask_
=
jax
.
device_put
(
mask
,
mask_named_sharding
)
with
warnings
.
catch_warnings
(
record
=
True
)
as
warns
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
warns
:
try
:
try
:
...
@@ -116,8 +118,8 @@ class TestDistributedSoftmax:
...
@@ -116,8 +118,8 @@ class TestDistributedSoftmax:
grad_args
=
(
0
,),
grad_args
=
(
0
,),
metric_fwd_dtype
=
dtype
,
metric_fwd_dtype
=
dtype
,
metric_bwd_dtype
=
dtype
,
metric_bwd_dtype
=
dtype
,
in_shardings
=
(
x_
pspec
,
mask_pspec
),
in_shardings
=
(
x_
named_sharding
,
mask_named_sharding
),
out_shardings
=
(
None
,
(
x_
pspec
,)
),
out_shardings
=
(
None
,
x_
named_sharding
),
)
)
except
AssertionError
as
err
:
except
AssertionError
as
err
:
# Softmax should still produce the correct numerical result with
# Softmax should still produce the correct numerical result with
...
...
tests/jax/test_fused_attn.py
View file @
0a5016b1
...
@@ -378,14 +378,14 @@ class FusedAttnRunner:
...
@@ -378,14 +378,14 @@ class FusedAttnRunner:
pytest
.
skip
(
pytest
.
skip
(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
)
# TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
if
(
if
(
get_device_compute_capability
(
0
)
=
=
100
get_device_compute_capability
(
0
)
>
=
100
and
self
.
dropout_prob
==
0.1
and
self
.
dropout_prob
==
0.1
and
self
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
and
self
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
):
):
pytest
.
skip
(
pytest
.
skip
(
"For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
"For sm100
+
, bprop kernel support for dropout + determinism (bias) is not supported"
)
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
...
...
tests/jax/test_helper.py
View file @
0a5016b1
...
@@ -3,11 +3,13 @@
...
@@ -3,11 +3,13 @@
# See LICENSE for license information.
# See LICENSE for license information.
import
unittest
import
unittest
from
functools
import
partial
import
flax
import
flax
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
import
numpy
as
np
import
numpy
as
np
from
flax
import
linen
as
nn
from
utils
import
assert_allclose
from
utils
import
assert_allclose
from
transformer_engine.common.recipe
import
(
from
transformer_engine.common.recipe
import
(
...
@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import (
...
@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import (
ScalingMode
,
ScalingMode
,
update_collections
,
update_collections
,
TensorSource
,
TensorSource
,
QuantizerFactory
,
QuantizeLayout
,
)
)
from
transformer_engine.jax.quantize.helper
import
_format2dtypes
from
transformer_engine.jax.quantize.helper
import
_format2dtypes
from
transformer_engine.jax.sharding
import
MeshResource
,
global_mesh_resource
from
transformer_engine.jax.sharding
import
MeshResource
,
global_mesh_resource
from
transformer_engine.jax.flax.module
import
TransformerEngineBase
is_fp8_supported
,
reason
=
is_scaling_mode_supported
(
ScalingMode
.
DELAYED_TENSOR_SCALING
)
is_fp8_supported
,
reason
=
is_scaling_mode_supported
(
ScalingMode
.
DELAYED_TENSOR_SCALING
)
is_mxfp8_supported
,
mxfp8_reason
=
is_scaling_mode_supported
(
ScalingMode
.
MXFP8_1D_SCALING
)
is_mxfp8_supported
,
mxfp8_reason
=
is_scaling_mode_supported
(
ScalingMode
.
MXFP8_1D_SCALING
)
is_nvfp4_supported
,
nvfp4_reason
=
is_scaling_mode_supported
(
ScalingMode
.
NVFP4_1D_SCALING
)
is_nvfp4_supported
,
nvfp4_reason
=
is_scaling_mode_supported
(
ScalingMode
.
NVFP4_1D_SCALING
)
def
quantizer_check_vjp
(
outer_quantizer_set
,
assertion_func
,
x
):
"""Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
# Define a function with a custom VJP (vector-Jacobian product)
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
1
,))
def
quantizer_check
(
inner_quantizer_set
,
assertion_func
,
x
):
return
quantizer_check_fwd
(
inner_quantizer_set
,
assertion_func
,
x
)
def
quantizer_check_fwd
(
inner_quantizer_set
,
assertion_func
,
x
):
assertion_func
(
inner_quantizer_set
.
x
,
TensorSource
.
X
)
assertion_func
(
inner_quantizer_set
.
kernel
,
TensorSource
.
KERNEL
)
assertion_func
(
inner_quantizer_set
.
dgrad
,
TensorSource
.
DGRAD
)
return
x
def
quantizer_check_bwd
(
ctx
,
g
):
return
(
g
,)
quantizer_check
.
defvjp
(
quantizer_check_fwd
,
quantizer_check_bwd
)
return
quantizer_check
(
outer_quantizer_set
,
assertion_func
,
x
)
class
TestModule
(
TransformerEngineBase
):
"""A simple module to test quantizer creation and reconstruction across VJP boundaries."""
# Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
assertion_func
:
callable
@
nn
.
compact
def
__call__
(
self
,
x
):
quantizer_set
=
self
.
generate_quantizer_set
()
return
quantizer_check_vjp
(
quantizer_set
,
self
.
assertion_func
,
x
)
class
TestHelper
(
unittest
.
TestCase
):
class
TestHelper
(
unittest
.
TestCase
):
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
...
@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase):
...
@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase):
for
tensor_source
in
TensorSource
:
for
tensor_source
in
TensorSource
:
target_scaling_mode
=
(
target_scaling_mode
=
(
ScalingMode
.
NVFP4_2D_SCALING
ScalingMode
.
NVFP4_2D_SCALING
if
tensor_source
==
TensorSource
.
KERNEL
if
(
not
test
.
disable_2d_quantization
)
and
tensor_source
==
TensorSource
.
KERNEL
else
ScalingMode
.
NVFP4_1D_SCALING
else
ScalingMode
.
NVFP4_1D_SCALING
)
)
self
.
assertEqual
(
self
.
assertEqual
(
get_quantize_config
().
get_scaling_mode
(
tensor_source
),
target_scaling_mode
get_quantize_config
().
get_scaling_mode
(
tensor_source
),
target_scaling_mode
)
)
self
.
assertEqual
(
get_quantize_config
().
DISABLE_STOCHASTIC_ROUNDING
,
test
.
disable_stochastic_rounding
)
self
.
assertEqual
(
get_quantize_config
().
DISABLE_RHT
,
test
.
disable_rht
)
self
.
assertEqual
(
get_quantize_config
().
DISABLE_2D_QUANTIZATION
,
test
.
disable_2d_quantization
)
def
_compare_nvfp4_scaling_quantizers
(
self
,
test
):
"""Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""
def
assertion_func
(
quantizer
,
tensor_source
):
if
test
.
disable_stochastic_rounding
or
tensor_source
!=
TensorSource
.
DGRAD
:
self
.
assertIsNone
(
quantizer
.
stochastic_rounding_rng_state
)
else
:
self
.
assertIsNotNone
(
quantizer
.
stochastic_rounding_rng_state
)
expected_rht
=
(
quantizer
.
scaling_mode
==
ScalingMode
.
NVFP4_1D_SCALING
and
quantizer
.
q_layout
in
{
QuantizeLayout
.
ROWWISE_COLWISE
,
QuantizeLayout
.
COLWISE
}
and
not
test
.
disable_rht
)
self
.
assertEqual
(
quantizer
.
use_rht
,
expected_rht
)
x
=
jnp
.
ones
((),
dtype
=
jnp
.
float32
)
test_module
=
TestModule
(
assertion_func
=
assertion_func
)
param_key
,
sr_key
=
jax
.
random
.
split
(
jax
.
random
.
PRNGKey
(
0
))
rngs
=
{
"params"
:
param_key
,
"sr_rng"
:
sr_key
}
variables
=
test_module
.
init
(
rngs
,
x
)
jax
.
jit
(
jax
.
value_and_grad
(
test_module
.
apply
),
static_argnums
=
(
2
,))(
variables
,
x
,
rngs
=
rngs
)
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_autocast_delayed_scaling
(
self
):
def
test_autocast_delayed_scaling
(
self
):
...
@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase):
...
@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase):
with
autocast
(
enabled
=
True
,
recipe
=
bs
,
mesh_resource
=
MeshResource
()):
with
autocast
(
enabled
=
True
,
recipe
=
bs
,
mesh_resource
=
MeshResource
()):
self
.
assertTrue
(
get_quantize_config
().
is_fp8_enabled
())
self
.
assertTrue
(
get_quantize_config
().
is_fp8_enabled
())
self
.
_compare_nvfp4_scaling
(
bs
)
self
.
_compare_nvfp4_scaling
(
bs
)
self
.
_compare_nvfp4_scaling_quantizers
(
bs
)
bs
=
NVFP4BlockScaling
(
disable_stochastic_rounding
=
True
,
disable_rht
=
True
,
disable_2d_quantization
=
True
,
)
with
autocast
(
enabled
=
True
,
recipe
=
bs
,
mesh_resource
=
MeshResource
()):
self
.
assertTrue
(
get_quantize_config
().
is_fp8_enabled
())
self
.
_compare_nvfp4_scaling
(
bs
)
self
.
_compare_nvfp4_scaling_quantizers
(
bs
)
self
.
_check_default_state
()
self
.
_check_default_state
()
tests/pytorch/attention/run_attention_with_cp.py
View file @
0a5016b1
...
@@ -248,6 +248,7 @@ def run_dpa_with_cp(
...
@@ -248,6 +248,7 @@ def run_dpa_with_cp(
attn_mask_type
=
config
.
attn_mask_type
,
attn_mask_type
=
config
.
attn_mask_type
,
window_size
=
config
.
window_size
,
window_size
=
config
.
window_size
,
softmax_type
=
config
.
softmax_type
,
softmax_type
=
config
.
softmax_type
,
return_max_logit
=
config
.
return_max_logit
,
).
cuda
()
).
cuda
()
if
config
.
softmax_type
!=
"vanilla"
:
if
config
.
softmax_type
!=
"vanilla"
:
core_attn
.
softmax_offset
.
requires_grad
=
True
core_attn
.
softmax_offset
.
requires_grad
=
True
...
@@ -308,6 +309,7 @@ def run_dpa_with_cp(
...
@@ -308,6 +309,7 @@ def run_dpa_with_cp(
fp8_context
=
autocast
(
enabled
=
True
,
recipe
=
fp8_recipe
,
amax_reduction_group
=
cp_comm_group
)
fp8_context
=
autocast
(
enabled
=
True
,
recipe
=
fp8_recipe
,
amax_reduction_group
=
cp_comm_group
)
else
:
else
:
fp8_context
=
nullcontext
()
fp8_context
=
nullcontext
()
max_logit
=
None
with
fp8_context
:
with
fp8_context
:
# q, k, v, out in FP8; dout in F16
# q, k, v, out in FP8; dout in F16
out
=
core_attn
(
out
=
core_attn
(
...
@@ -322,6 +324,8 @@ def run_dpa_with_cp(
...
@@ -322,6 +324,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded
=
cu_seqlens_kv_padded
,
cu_seqlens_kv_padded
=
cu_seqlens_kv_padded
,
fp8_output
=
fp8_mha
,
fp8_output
=
fp8_mha
,
)
)
if
config
.
return_max_logit
:
out
,
max_logit
=
out
if
fp8_bwd
and
fp8_mha
:
if
fp8_bwd
and
fp8_mha
:
dout_fp8
=
dout_quantizer
(
dout
)
dout_fp8
=
dout_quantizer
(
dout
)
out
.
backward
(
dout_fp8
)
out
.
backward
(
dout_fp8
)
...
@@ -400,6 +404,7 @@ def run_dpa_with_cp(
...
@@ -400,6 +404,7 @@ def run_dpa_with_cp(
fp8_context
=
nullcontext
()
fp8_context
=
nullcontext
()
# run attention
# run attention
max_logit_
=
None
with
fp8_context
:
with
fp8_context
:
# q, k, v, out in FP8; dout in F16
# q, k, v, out in FP8; dout in F16
out_
=
core_attn
(
out_
=
core_attn
(
...
@@ -414,6 +419,8 @@ def run_dpa_with_cp(
...
@@ -414,6 +419,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded
=
cu_seqlens_kv_padded
,
cu_seqlens_kv_padded
=
cu_seqlens_kv_padded
,
fp8_output
=
fp8_mha
,
fp8_output
=
fp8_mha
,
)
)
if
config
.
return_max_logit
:
out_
,
max_logit_
=
out_
if
fp8_bwd
and
fp8_mha
:
if
fp8_bwd
and
fp8_mha
:
dout_fp8_
=
dout_quantizer
(
dout_
)
dout_fp8_
=
dout_quantizer
(
dout_
)
out_
.
backward
(
dout_fp8_
)
out_
.
backward
(
dout_fp8_
)
...
@@ -495,15 +502,15 @@ def run_dpa_with_cp(
...
@@ -495,15 +502,15 @@ def run_dpa_with_cp(
)
)
atol
,
rtol
,
rmse_tol
=
get_tols
(
config
,
dtype
)
atol
,
rtol
,
rmse_tol
=
get_tols
(
config
,
dtype
)
tensors_cp
=
[
out_
,
dq_
,
dk_
,
dv_
,
d_softmax_offset_
]
tensors_cp
=
[
out_
,
dq_
,
dk_
,
dv_
,
d_softmax_offset_
,
max_logit_
]
tensors_no_cp
=
[
out
,
dq
,
dk
,
dv
,
d_softmax_offset
]
tensors_no_cp
=
[
out
,
dq
,
dk
,
dv
,
d_softmax_offset
,
max_logit
]
names
=
[
"out"
,
"dq"
,
"dk"
,
"dv"
,
"d_softmax_offset"
]
names
=
[
"out"
,
"dq"
,
"dk"
,
"dv"
,
"d_softmax_offset"
,
"max_logit"
]
names_cp
=
[
x
+
"_cp"
for
x
in
names
]
names_cp
=
[
x
+
"_cp"
for
x
in
names
]
names_no_cp
=
[
x
+
"_no_cp"
for
x
in
names
]
names_no_cp
=
[
x
+
"_no_cp"
for
x
in
names
]
is_fp8
=
dtype
==
"fp8"
is_fp8
=
dtype
==
"fp8"
for
i
,
t
in
enumerate
(
tensors_no_cp
):
for
i
,
t
in
enumerate
(
tensors_no_cp
):
if
t
is
not
None
:
if
t
is
not
None
:
if
"softmax_offset"
not
in
names
[
i
]:
if
"softmax_offset"
not
in
names
[
i
]
and
"max_logit"
not
in
names
[
i
]:
if
qkv_format
==
"bshd"
:
if
qkv_format
==
"bshd"
:
compare_and_assert
(
compare_and_assert
(
t
[:,
0
],
t
[:,
0
],
...
...
tests/pytorch/attention/test_attention.py
View file @
0a5016b1
...
@@ -60,8 +60,16 @@ from utils import (
...
@@ -60,8 +60,16 @@ from utils import (
get_available_attention_backends
,
get_available_attention_backends
,
)
)
# Check if hardware supports FP8
# Check if hardware supports FP8
attention.
fp8_available
,
reason_for_no_fp8
=
is_fp8_available
(
return_reason
=
True
)
fp8_available
,
reason_for_no_fp8
=
is_fp8_available
(
return_reason
=
True
)
fp8_attn_available
,
reason_for_no_fp8_attn
=
fp8_available
,
reason_for_no_fp8
device_compute_capability
=
get_device_compute_capability
()
if
fp8_available
and
(
device_compute_capability
<
(
9
,
0
)
or
device_compute_capability
>=
(
12
,
0
)):
fp8_attn_available
=
False
reason_for_no_fp8_attn
=
(
"FP8 attention is not supported for compute capability ="
f
" sm
{
device_compute_capability
[
0
]
*
10
+
device_compute_capability
[
1
]
}
"
)
# Reset RNG seed and states
# Reset RNG seed and states
seed
=
1234
seed
=
1234
...
@@ -130,6 +138,11 @@ def test_dot_product_attention(
...
@@ -130,6 +138,11 @@ def test_dot_product_attention(
if
config
.
window_size
==
(
-
1
,
-
1
)
and
swa
:
if
config
.
window_size
==
(
-
1
,
-
1
)
and
swa
:
config
.
window_size
=
[
2
,
2
]
config
.
window_size
=
[
2
,
2
]
config
.
window_size
=
check_set_window_size
(
config
.
attn_mask_type
,
config
.
window_size
)
config
.
window_size
=
check_set_window_size
(
config
.
attn_mask_type
,
config
.
window_size
)
qkv_format
=
qkv_layout
.
replace
(
"3"
,
""
).
replace
(
"2"
,
""
).
split
(
"_"
)[
0
]
if
qkv_format
==
"thd"
and
"padding"
not
in
config
.
attn_mask_type
:
config
.
attn_mask_type
=
(
"padding_"
+
config
.
attn_mask_type
if
config
.
attn_mask_type
!=
"no_mask"
else
"padding"
)
# Get backends
# Get backends
is_training
=
True
is_training
=
True
...
@@ -171,7 +184,7 @@ def test_dot_product_attention(
...
@@ -171,7 +184,7 @@ def test_dot_product_attention(
# UnfusedDotProductAttention backend
# UnfusedDotProductAttention backend
if
unfused_attn_supported
:
if
unfused_attn_supported
:
unfused_attn_fwd
,
unfused_attn_bwd
=
_run_dot_product_attention
(
unfused_attn_fwd
,
unfused_max_logit
,
unfused_attn_bwd
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"UnfusedDotProductAttention"
,
"UnfusedDotProductAttention"
,
...
@@ -185,7 +198,7 @@ def test_dot_product_attention(
...
@@ -185,7 +198,7 @@ def test_dot_product_attention(
# FusedAttention backend
# FusedAttention backend
if
fused_attn_supported
:
if
fused_attn_supported
:
if
len
(
fused_attn_backends
)
==
1
:
if
len
(
fused_attn_backends
)
==
1
:
fused_attn_fwd
,
fused_attn_bwd
=
_run_dot_product_attention
(
fused_attn_fwd
,
fused_max_logit
,
fused_attn_bwd
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"FusedAttention"
,
"FusedAttention"
,
...
@@ -197,7 +210,7 @@ def test_dot_product_attention(
...
@@ -197,7 +210,7 @@ def test_dot_product_attention(
)
)
if
len
(
fused_attn_backends
)
==
2
:
if
len
(
fused_attn_backends
)
==
2
:
os
.
environ
[
"NVTE_FUSED_ATTN_BACKEND"
]
=
"0"
os
.
environ
[
"NVTE_FUSED_ATTN_BACKEND"
]
=
"0"
fused_attn_fwd
,
fused_attn_bwd
=
_run_dot_product_attention
(
fused_attn_fwd
,
_
,
fused_attn_bwd
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"FusedAttention"
,
"FusedAttention"
,
...
@@ -208,7 +221,7 @@ def test_dot_product_attention(
...
@@ -208,7 +221,7 @@ def test_dot_product_attention(
is_training
,
is_training
,
)
)
os
.
environ
[
"NVTE_FUSED_ATTN_BACKEND"
]
=
"1"
os
.
environ
[
"NVTE_FUSED_ATTN_BACKEND"
]
=
"1"
fused_attn_fwd_1
,
fused_attn_bwd_1
=
_run_dot_product_attention
(
fused_attn_fwd_1
,
_
,
fused_attn_bwd_1
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"FusedAttention"
,
"FusedAttention"
,
...
@@ -221,7 +234,7 @@ def test_dot_product_attention(
...
@@ -221,7 +234,7 @@ def test_dot_product_attention(
# FlashAttention backend
# FlashAttention backend
if
flash_attn_supported
:
if
flash_attn_supported
:
flash_attn_fwd
,
flash_attn_bwd
=
_run_dot_product_attention
(
flash_attn_fwd
,
_
,
flash_attn_bwd
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"FlashAttention"
,
"FlashAttention"
,
...
@@ -242,6 +255,8 @@ def test_dot_product_attention(
...
@@ -242,6 +255,8 @@ def test_dot_product_attention(
if
unfused_attn_supported
and
fused_attn_supported
:
if
unfused_attn_supported
and
fused_attn_supported
:
logging
.
info
(
"[test_dot_product_attention]: unfused attn vs fused attn"
)
logging
.
info
(
"[test_dot_product_attention]: unfused attn vs fused attn"
)
torch
.
testing
.
assert_close
(
fused_attn_fwd
,
unfused_attn_fwd
,
**
tols
)
torch
.
testing
.
assert_close
(
fused_attn_fwd
,
unfused_attn_fwd
,
**
tols
)
if
config
.
return_max_logit
:
torch
.
testing
.
assert_close
(
fused_max_logit
,
unfused_max_logit
,
**
tols
)
for
i
,
_
in
enumerate
(
unfused_attn_bwd
):
for
i
,
_
in
enumerate
(
unfused_attn_bwd
):
torch
.
testing
.
assert_close
(
fused_attn_bwd
[
i
],
unfused_attn_bwd
[
i
],
**
tols
)
torch
.
testing
.
assert_close
(
fused_attn_bwd
[
i
],
unfused_attn_bwd
[
i
],
**
tols
)
if
fused_attn_supported
and
flash_attn_supported
:
if
fused_attn_supported
and
flash_attn_supported
:
...
@@ -265,6 +280,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
...
@@ -265,6 +280,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
True
,
True
,
None
,
False
,
False
)
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
True
,
True
,
None
,
False
,
False
)
model_configs_max_logit
=
{
# test: ModelConfig(b, sq, hq, dqk)
"max_logit_1"
:
ModelConfig
(
1
,
2048
,
24
,
128
,
max_seqlen_kv
=
4096
),
"max_logit_2"
:
ModelConfig
(
2
,
2048
,
24
,
128
,
attn_mask_type
=
"causal"
),
"max_logit_3"
:
ModelConfig
(
2
,
1
,
16
,
128
,
max_seqlen_kv
=
2048
,
attn_mask_type
=
"padding_causal"
),
"max_logit_4"
:
ModelConfig
(
8
,
128
,
16
,
192
,
max_seqlen_kv
=
2048
,
attn_bias_type
=
"post_scale_bias"
),
"max_logit_5"
:
ModelConfig
(
8
,
128
,
16
,
512
,
max_seqlen_kv
=
2048
,
attn_mask_type
=
"causal"
,
window_size
=
(
20
,
0
)
),
"max_logit_6"
:
ModelConfig
(
8
,
1
,
16
,
1024
,
max_seqlen_kv
=
2048
),
}
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
8
,
9
,
1
),
reason
=
"cuDNN 8.9.1+ is required."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types
)
@
pytest
.
mark
.
parametrize
(
"model_configs"
,
[
model_configs_max_logit
])
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_max_logit
.
keys
())
@
pytest
.
mark
.
parametrize
(
"qkv_layout"
,
[
"sbhd_sbhd_sbhd"
,
"thd_thd_thd"
])
def
test_dpa_max_logit
(
dtype
,
model_configs
,
model
,
qkv_layout
):
"""Test DotProductAttention module with checkpointing"""
config
=
model_configs
[
model
]
config
.
return_max_logit
=
True
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
False
,
True
,
qkv_layout
,
False
,
False
)
model_configs_softmax
=
{
model_configs_softmax
=
{
# test: ModelConfig(b, sq, hq, dqk)
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0"
:
ModelConfig
(
2
,
2048
,
64
,
64
,
num_gqa_groups
=
8
),
"softmax_1_0"
:
ModelConfig
(
2
,
2048
,
64
,
64
,
num_gqa_groups
=
8
),
...
@@ -962,6 +1004,8 @@ def _run_dot_product_attention(
...
@@ -962,6 +1004,8 @@ def _run_dot_product_attention(
layout
=
layout
.
replace
(
"d"
,
"dqk"
)
layout
=
layout
.
replace
(
"d"
,
"dqk"
)
tensor_shape
=
[
dim_to_num
[
j
]
for
j
in
layout
.
split
(
"_"
)]
tensor_shape
=
[
dim_to_num
[
j
]
for
j
in
layout
.
split
(
"_"
)]
tensor
=
0.1
*
torch
.
randn
(
tensor_shape
,
dtype
=
dtype
,
device
=
"cuda"
)
tensor
=
0.1
*
torch
.
randn
(
tensor_shape
,
dtype
=
dtype
,
device
=
"cuda"
)
# tensor: with padding tokens
# tensor_orig: without padding tokens
tensor_orig
=
tensor
tensor_orig
=
tensor
if
qkv_format
==
"thd"
and
pad_between_seqs
:
if
qkv_format
==
"thd"
and
pad_between_seqs
:
tensor_orig
=
torch
.
Tensor
([]).
to
(
device
=
"cuda"
,
dtype
=
dtype
)
tensor_orig
=
torch
.
Tensor
([]).
to
(
device
=
"cuda"
,
dtype
=
dtype
)
...
@@ -1071,6 +1115,7 @@ def _run_dot_product_attention(
...
@@ -1071,6 +1115,7 @@ def _run_dot_product_attention(
layer_number
=
1
,
layer_number
=
1
,
attention_type
=
config
.
attn_type
,
attention_type
=
config
.
attn_type
,
softmax_type
=
config
.
softmax_type
,
softmax_type
=
config
.
softmax_type
,
return_max_logit
=
config
.
return_max_logit
,
).
to
(
dtype
=
dtype
,
device
=
"cuda"
)
).
to
(
dtype
=
dtype
,
device
=
"cuda"
)
if
not
is_training
:
if
not
is_training
:
block
=
block
.
eval
()
block
=
block
.
eval
()
...
@@ -1108,16 +1153,21 @@ def _run_dot_product_attention(
...
@@ -1108,16 +1153,21 @@ def _run_dot_product_attention(
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
fast_zero_fill
=
True
,
fast_zero_fill
=
True
,
)
)
max_logit
=
None
if
config
.
return_max_logit
:
out
,
max_logit
=
out
if
is_training
:
if
is_training
:
out
.
backward
(
d_out
)
out
.
backward
(
d_out
)
d_softmax_offset
=
None
d_softmax_offset
=
None
if
is_training
and
config
.
softmax_type
!=
"vanilla"
:
if
is_training
and
config
.
softmax_type
!=
"vanilla"
:
d_softmax_offset
=
block
.
softmax_offset
.
grad
d_softmax_offset
=
block
.
softmax_offset
.
grad
if
backend
in
[
"FlashAttention"
,
"UnfusedDotProductAttention"
]:
if
backend
in
[
"FlashAttention"
,
"UnfusedDotProductAttention"
]:
if
is_training
:
if
is_training
:
return
out
,
(
q
.
grad
,
k
.
grad
,
v
.
grad
,
d_softmax_offset
)
return
out
,
max_logit
,
(
q
.
grad
,
k
.
grad
,
v
.
grad
,
d_softmax_offset
)
else
:
else
:
return
out
,
(
None
,
None
,
None
,
d_softmax_offset
)
return
out
,
max_logit
,
(
None
,
None
,
None
,
d_softmax_offset
)
if
backend
==
"FusedAttention"
:
if
backend
==
"FusedAttention"
:
if
qkv_format
==
"thd"
and
pad_between_seqs
:
if
qkv_format
==
"thd"
and
pad_between_seqs
:
out_orig
=
torch
.
Tensor
([]).
to
(
device
=
"cuda"
,
dtype
=
dtype
)
out_orig
=
torch
.
Tensor
([]).
to
(
device
=
"cuda"
,
dtype
=
dtype
)
...
@@ -1146,14 +1196,18 @@ def _run_dot_product_attention(
...
@@ -1146,14 +1196,18 @@ def _run_dot_product_attention(
[
v_grad_orig
,
v
.
grad
[
valid_range_kv
[
0
]
:
valid_range_kv
[
1
]]],
dim
=
0
[
v_grad_orig
,
v
.
grad
[
valid_range_kv
[
0
]
:
valid_range_kv
[
1
]]],
dim
=
0
)
)
if
is_training
:
if
is_training
:
return
out_orig
,
(
q_grad_orig
,
k_grad_orig
,
v_grad_orig
,
d_softmax_offset
)
return
(
out_orig
,
max_logit
,
(
q_grad_orig
,
k_grad_orig
,
v_grad_orig
,
d_softmax_offset
),
)
else
:
else
:
return
out_orig
,
(
None
,
None
,
None
,
d_softmax_offset
)
return
out_orig
,
max_logit
,
(
None
,
None
,
None
,
d_softmax_offset
)
else
:
else
:
if
is_training
:
if
is_training
:
return
out
,
(
q
.
grad
,
k
.
grad
,
v
.
grad
,
d_softmax_offset
)
return
out
,
max_logit
,
(
q
.
grad
,
k
.
grad
,
v
.
grad
,
d_softmax_offset
)
else
:
else
:
return
out
,
(
None
,
None
,
None
,
d_softmax_offset
)
return
out
,
max_logit
,
(
None
,
None
,
None
,
d_softmax_offset
)
model_configs_te_layer
=
{
model_configs_te_layer
=
{
...
@@ -1527,8 +1581,7 @@ model_configs_fp8_extra_state = {
...
@@ -1527,8 +1581,7 @@ model_configs_fp8_extra_state = {
}
}
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
@
pytest
.
mark
.
skipif
(
not
fp8_attn_available
,
reason
=
reason_for_no_fp8_attn
)
@
pytest
.
mark
.
skipif
(
get_device_compute_capability
()
<
(
9
,
0
),
reason
=
"FP8 tests require Hopper."
)
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
3
,
0
),
reason
=
"cuDNN 9.3.0+ is required."
)
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
3
,
0
),
reason
=
"cuDNN 9.3.0+ is required."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"large"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"large"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
...
@@ -1690,8 +1743,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
...
@@ -1690,8 +1743,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
2
,
1
),
reason
=
"cuDNN 9.2.1+ is required."
)
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
2
,
1
),
reason
=
"cuDNN 9.2.1+ is required."
)
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
@
pytest
.
mark
.
skipif
(
not
fp8_attn_available
,
reason
=
reason_for_no_fp8_attn
)
@
pytest
.
mark
.
skipif
(
get_device_compute_capability
()
<
(
9
,
0
),
reason
=
"FP8 tests require Hopper+."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_fp8_vs_f16
.
keys
())
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_fp8_vs_f16
.
keys
())
@
pytest
.
mark
.
parametrize
(
"qkv_format"
,
qkv_format_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"qkv_format"
,
qkv_format_fp8_vs_f16
)
...
@@ -1927,8 +1979,7 @@ def _run_mha_fp8_vs_f16(
...
@@ -1927,8 +1979,7 @@ def _run_mha_fp8_vs_f16(
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
2
,
1
),
reason
=
"cuDNN 9.2.1+ is required."
)
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
2
,
1
),
reason
=
"cuDNN 9.2.1+ is required."
)
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
@
pytest
.
mark
.
skipif
(
not
fp8_attn_available
,
reason
=
reason_for_no_fp8_attn
)
@
pytest
.
mark
.
skipif
(
get_device_compute_capability
()
<
(
9
,
0
),
reason
=
"FP8 tests require Hopper+."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_fp8_vs_f16
.
keys
())
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_fp8_vs_f16
.
keys
())
@
pytest
.
mark
.
parametrize
(
"qkv_layout"
,
qkv_layout_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"qkv_layout"
,
qkv_layout_fp8_vs_f16
)
...
@@ -2256,8 +2307,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
...
@@ -2256,8 +2307,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
),
),
reason
=
f
"""cuDNN
{
"8.9.3"
if
cudnn_frontend_version
==
0
else
"9.2.1"
}
+ is required."""
,
reason
=
f
"""cuDNN
{
"8.9.3"
if
cudnn_frontend_version
==
0
else
"9.2.1"
}
+ is required."""
,
)
)
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
@
pytest
.
mark
.
skipif
(
not
fp8_attn_available
,
reason
=
reason_for_no_fp8_attn
)
@
pytest
.
mark
.
skipif
(
get_device_compute_capability
()
<
(
9
,
0
),
reason
=
"FP8 tests require Hopper+."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models_v1
if
cudnn_frontend_version
==
1
else
models_v0
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models_v1
if
cudnn_frontend_version
==
1
else
models_v0
)
def
test_custom_mha_fp8_vs_f16
(
dtype
,
model
):
def
test_custom_mha_fp8_vs_f16
(
dtype
,
model
):
...
...
tests/pytorch/attention/test_attention_with_cp.py
View file @
0a5016b1
...
@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
...
@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn
=
{
model_configs_fused_attn
=
{
# test: ModelConfig(b, sq, hq, dqk)
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0"
:
ModelConfig
(
2
,
4096
,
12
,
128
,
attn_mask_type
=
"causal"
),
# MHA
"cp_1_0"
:
ModelConfig
(
2
,
4096
,
12
,
128
,
attn_mask_type
=
"causal"
,
return_max_logit
=
True
),
# MHA
"cp_1_1"
:
ModelConfig
(
2
,
4096
,
12
,
128
),
# MHA
"cp_1_1"
:
ModelConfig
(
2
,
4096
,
12
,
128
,
return_max_logit
=
True
),
# MHA
"cp_1_2"
:
ModelConfig
(
"cp_1_2"
:
ModelConfig
(
2
,
4096
,
12
,
128
,
attn_mask_type
=
"causal"
,
attn_bias_type
=
"post_scale_bias"
2
,
4096
,
12
,
128
,
attn_mask_type
=
"causal"
,
attn_bias_type
=
"post_scale_bias"
),
# MHA
),
# MHA
...
@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
...
@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats
=
[
"bshd"
,
"sbhd"
,
"thd"
]
qkv_formats
=
[
"bshd"
,
"sbhd"
,
"thd"
]
cp_comm_types
=
[
"p2p"
,
"all_gather"
,
"a2a"
,
"a2a+p2p"
]
cp_comm_types
=
[
"p2p"
,
"all_gather"
,
"a2a"
,
"a2a+p2p"
]
if
test_essential
:
if
test_essential
:
configs
=
[
"cp_1_0"
,
"cp_2_0"
,
"cp_2_2"
,
"cp_3_2"
,
"cp_4_2"
]
configs
=
[
"cp_1_0"
,
"cp_1_1"
,
"cp_2_0"
,
"cp_2_2"
,
"cp_3_2"
,
"cp_4_2"
]
model_configs_fused_attn
=
{
k
:
model_configs_fused_attn
[
k
]
for
k
in
configs
}
model_configs_fused_attn
=
{
k
:
model_configs_fused_attn
[
k
]
for
k
in
configs
}
dtypes
=
[
"bf16"
,
"fp8"
]
dtypes
=
[
"bf16"
,
"fp8"
]
qkv_formats
=
[
"sbhd"
,
"thd"
]
qkv_formats
=
[
"sbhd"
,
"thd"
]
...
...
tests/pytorch/test_numerics.py
View file @
0a5016b1
...
@@ -45,11 +45,10 @@ from transformer_engine.pytorch import (
...
@@ -45,11 +45,10 @@ from transformer_engine.pytorch import (
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.cpp_extensions.fused_attn
import
FusedAttnBackend
from
transformer_engine.pytorch.module.base
import
get_multi_stream_cublas_workspace
,
get_workspace
from
transformer_engine.pytorch.module.base
import
get_multi_stream_cublas_workspace
,
get_workspace
from
transformer_engine.common
import
recipe
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
utils
import
ModelConfig
,
reset_rng_states
,
get_available_attention_backends
from
utils
import
ModelConfig
,
reset_rng_states
# Only run FP8 tests on supported devices.
# Only run FP8 tests on supported devices.
...
@@ -135,23 +134,6 @@ if torch.cuda.get_device_capability() == (9, 0):
...
@@ -135,23 +134,6 @@ if torch.cuda.get_device_capability() == (9, 0):
use_cutlass_grouped_gemm
.
append
(
True
)
use_cutlass_grouped_gemm
.
append
(
True
)
def
is_fused_attn_available
(
config
:
ModelConfig
,
dtype
:
torch
.
dtype
,
qkv_layout
=
"bshd_bshd_bshd"
,
is_training
=
True
,
deterministic
=
False
,
):
_
,
_
,
fused_attn_backends
=
get_available_attention_backends
(
config
,
qkv_dtype
=
dtype
,
qkv_layout
=
qkv_layout
,
is_training
=
is_training
,
deterministic
=
deterministic
,
)
return
FusedAttnBackend
[
"F16_arbitrary_seqlen"
]
in
fused_attn_backends
def
get_causal_attn_mask
(
sq
:
int
)
->
torch
.
Tensor
:
def
get_causal_attn_mask
(
sq
:
int
)
->
torch
.
Tensor
:
return
torch
.
triu
(
torch
.
ones
(
sq
,
sq
,
device
=
"cuda"
),
diagonal
=
1
).
bool
()
return
torch
.
triu
(
torch
.
ones
(
sq
,
sq
,
device
=
"cuda"
),
diagonal
=
1
).
bool
()
...
@@ -872,8 +854,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
...
@@ -872,8 +854,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"126m"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"126m"
])
def
test_gpt_checkpointing
(
dtype
,
bs
,
model
):
def
test_gpt_checkpointing
(
dtype
,
bs
,
model
):
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
not
is_fused_attn_available
(
config
,
dtype
,
deterministic
=
True
):
pytest
.
skip
(
"No attention backend available."
)
outputs
=
_test_e2e_checkpointing
(
bs
,
dtype
,
config
,
checkpoint
=
False
)
outputs
=
_test_e2e_checkpointing
(
bs
,
dtype
,
config
,
checkpoint
=
False
)
outputs_checkpoint
=
_test_e2e_checkpointing
(
bs
,
dtype
,
config
,
checkpoint
=
True
)
outputs_checkpoint
=
_test_e2e_checkpointing
(
bs
,
dtype
,
config
,
checkpoint
=
True
)
...
@@ -920,10 +900,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
...
@@ -920,10 +900,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@
pytest
.
mark
.
parametrize
(
"parallel_attention_mlp"
,
all_boolean
)
@
pytest
.
mark
.
parametrize
(
"parallel_attention_mlp"
,
all_boolean
)
def
test_gpt_accuracy
(
dtype
,
bs
,
model
,
parallel_attention_mlp
):
def
test_gpt_accuracy
(
dtype
,
bs
,
model
,
parallel_attention_mlp
):
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
not
is_fused_attn_available
(
config
,
dtype
,
qkv_layout
=
"sb3hd"
,
is_training
=
True
,
deterministic
=
True
):
pytest
.
skip
(
"No attention backend available."
)
te_gpt
=
TransformerLayer
(
te_gpt
=
TransformerLayer
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
...
@@ -1035,10 +1011,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
...
@@ -1035,10 +1011,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@
pytest
.
mark
.
parametrize
(
"mask_type"
,
mask_types
)
@
pytest
.
mark
.
parametrize
(
"mask_type"
,
mask_types
)
def
test_mha_accuracy
(
dtype
,
bs
,
model
,
mask_type
):
def
test_mha_accuracy
(
dtype
,
bs
,
model
,
mask_type
):
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
not
is_fused_attn_available
(
config
,
dtype
,
qkv_layout
=
"sb3hd"
,
is_training
=
True
,
deterministic
=
True
):
pytest
.
skip
(
"No attention backend available."
)
te_mha
=
MultiheadAttention
(
te_mha
=
MultiheadAttention
(
config
.
hidden_size
,
config
.
hidden_size
,
...
...
tests/pytorch/utils.py
View file @
0a5016b1
...
@@ -205,6 +205,7 @@ class ModelConfig:
...
@@ -205,6 +205,7 @@ class ModelConfig:
window_size
:
Tuple
[
int
,
int
]
=
(
-
1
,
-
1
),
window_size
:
Tuple
[
int
,
int
]
=
(
-
1
,
-
1
),
context_parallel
:
bool
=
False
,
context_parallel
:
bool
=
False
,
cp_comm_type
:
str
=
"p2p"
,
cp_comm_type
:
str
=
"p2p"
,
return_max_logit
=
False
,
total_requests
:
int
=
None
,
total_requests
:
int
=
None
,
max_ctx_len
:
int
=
None
,
max_ctx_len
:
int
=
None
,
num_layers
:
int
=
1
,
num_layers
:
int
=
1
,
...
@@ -233,6 +234,7 @@ class ModelConfig:
...
@@ -233,6 +234,7 @@ class ModelConfig:
self
.
window_size
=
check_set_window_size
(
self
.
attn_mask_type
,
window_size
)
self
.
window_size
=
check_set_window_size
(
self
.
attn_mask_type
,
window_size
)
self
.
context_parallel
=
context_parallel
self
.
context_parallel
=
context_parallel
self
.
cp_comm_type
=
cp_comm_type
self
.
cp_comm_type
=
cp_comm_type
self
.
return_max_logit
=
return_max_logit
self
.
total_requests
=
total_requests
self
.
total_requests
=
total_requests
self
.
max_ctx_len
=
max_ctx_len
self
.
max_ctx_len
=
max_ctx_len
self
.
num_layers
=
num_layers
self
.
num_layers
=
num_layers
...
@@ -318,6 +320,7 @@ def get_available_attention_backends(
...
@@ -318,6 +320,7 @@ def get_available_attention_backends(
is_training
=
is_training
,
is_training
=
is_training
,
inference_params
=
inference_params
,
inference_params
=
inference_params
,
softmax_type
=
config
.
softmax_type
,
softmax_type
=
config
.
softmax_type
,
return_max_logit
=
config
.
return_max_logit
,
)
)
(
(
use_flash_attention
,
use_flash_attention
,
...
...
transformer_engine/common/CMakeLists.txt
View file @
0a5016b1
...
@@ -29,35 +29,80 @@ endif()
...
@@ -29,35 +29,80 @@ endif()
# Language options
# Language options
if
(
USE_CUDA
)
if
(
USE_CUDA
)
if
(
NOT DEFINED CMAKE_CUDA_ARCHITECTURES
)
if
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0
)
set
(
CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120
)
elseif
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8
)
set
(
CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120
)
else
()
set
(
CMAKE_CUDA_ARCHITECTURES 70 80 89 90
)
endif
()
endif
()
set
(
CMAKE_CXX_STANDARD 17
)
set
(
CMAKE_CXX_STANDARD 17
)
set
(
CMAKE_CUDA_STANDARD 17
)
set
(
CMAKE_CUDA_STANDARD 17
)
set
(
CMAKE_CUDA_STANDARD_REQUIRED ON
)
set
(
CMAKE_CUDA_STANDARD_REQUIRED ON
)
if
(
CMAKE_BUILD_TYPE STREQUAL
"Debug"
)
if
(
CMAKE_BUILD_TYPE STREQUAL
"Debug"
)
set
(
CMAKE_CUDA_FLAGS_DEBUG
"
${
CMAKE_CUDA_FLAGS_DEBUG
}
-g -G"
)
set
(
CMAKE_CUDA_FLAGS_DEBUG
"
${
CMAKE_CUDA_FLAGS_DEBUG
}
-g -G"
)
endif
()
endif
()
# Hide non-necessary symbols in shared object.
# Hide non-necessary symbols in shared object.
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wl,--version-script=
${
CMAKE_CURRENT_SOURCE_DIR
}
/libtransformer_engine.version"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wl,--version-script=
${
CMAKE_CURRENT_SOURCE_DIR
}
/libtransformer_engine.version"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-Wl,--version-script=
${
CMAKE_CURRENT_SOURCE_DIR
}
/libtransformer_engine.version"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-Wl,--version-script=
${
CMAKE_CURRENT_SOURCE_DIR
}
/libtransformer_engine.version"
)
# Transformer Engine library
# Transformer Engine library
project
(
transformer_engine LANGUAGES CUDA CXX
)
project
(
transformer_engine LANGUAGES CUDA CXX
)
# CUDA Toolkit
# CUDA Toolkit
find_package
(
CUDAToolkit REQUIRED
)
find_package
(
CUDAToolkit REQUIRED
)
if
(
CUDAToolkit_VERSION VERSION_LESS 12.
0
)
if
(
CUDAToolkit_VERSION VERSION_LESS 12.
1
)
message
(
FATAL_ERROR
"CUDA 12.
0
+ is required, but found CUDA
${
CUDAToolkit_VERSION
}
"
)
message
(
FATAL_ERROR
"CUDA 12.
1
+ is required, but found CUDA
${
CUDAToolkit_VERSION
}
"
)
endif
()
endif
()
# Process GPU architectures
if
(
NOT DEFINED CMAKE_CUDA_ARCHITECTURES
)
if
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0
)
set
(
CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120
)
elseif
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8
)
set
(
CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120
)
else
()
set
(
CMAKE_CUDA_ARCHITECTURES 70 80 89 90
)
endif
()
endif
()
# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
set
(
NVTE_GENERIC_ARCHS
)
set
(
NVTE_SPECIFIC_ARCHS
)
# Check for architecture 100
list
(
FIND CMAKE_CUDA_ARCHITECTURES
"100"
arch_100_index
)
if
(
NOT arch_100_index EQUAL -1
)
list
(
REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES
"100"
)
list
(
APPEND NVTE_GENERIC_ARCHS
"100"
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"100a"
)
if
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"103a"
)
endif
()
endif
()
# Check for architecture 101 (if we see this we are in toolkit <= 12.9)
list
(
FIND CMAKE_CUDA_ARCHITECTURES
"101"
arch_101_index
)
if
(
NOT arch_101_index EQUAL -1
)
list
(
REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES
"101"
)
list
(
APPEND NVTE_GENERIC_ARCHS
"101"
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"101a"
)
endif
()
# Check for architecture 110 (if we see this we are in toolkit >= 13.0)
list
(
FIND CMAKE_CUDA_ARCHITECTURES
"110"
arch_110_index
)
if
(
NOT arch_110_index EQUAL -1
)
list
(
REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES
"110"
)
list
(
APPEND NVTE_GENERIC_ARCHS
"110"
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"110f"
)
endif
()
# Check for architecture 120
list
(
FIND CMAKE_CUDA_ARCHITECTURES
"120"
arch_120_index
)
if
(
NOT arch_120_index EQUAL -1
)
list
(
REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES
"120"
)
list
(
APPEND NVTE_GENERIC_ARCHS
"120"
)
if
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"120f"
)
else
()
list
(
APPEND NVTE_SPECIFIC_ARCHS
"120a"
)
endif
()
endif
()
# cuDNN frontend API
# cuDNN frontend API
set
(
CUDNN_FRONTEND_INCLUDE_DIR
set
(
CUDNN_FRONTEND_INCLUDE_DIR
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../3rdparty/cudnn-frontend/include"
)
"
${
CMAKE_CURRENT_SOURCE_DIR
}
/../../3rdparty/cudnn-frontend/include"
)
...
@@ -135,139 +180,206 @@ endif()
...
@@ -135,139 +180,206 @@ endif()
# Configure Transformer Engine library
# Configure Transformer Engine library
include_directories
(
${
PROJECT_SOURCE_DIR
}
/..
)
include_directories
(
${
PROJECT_SOURCE_DIR
}
/..
)
set
(
transformer_engine_SOURCES
)
set
(
transformer_engine_SOURCES
)
set
(
transformer_engine_cpp_sources
)
set
(
transformer_engine_cuda_sources
)
set
(
transformer_engine_cuda_arch_specific_sources
)
if
(
USE_CUDA
)
if
(
USE_CUDA
)
list
(
APPEND transformer_engine_SOURCES
list
(
APPEND transformer_engine_cpp_sources
cudnn_utils.cpp
cudnn_utils.cpp
transformer_engine.cpp
transformer_engine.cpp
common.cu
fused_attn/fused_attn.cpp
multi_tensor/adam.cu
gemm/config.cpp
multi_tensor/compute_scale.cu
normalization/common.cpp
multi_tensor/l2norm.cu
normalization/layernorm/ln_api.cpp
multi_tensor/scale.cu
normalization/rmsnorm/rmsnorm_api.cpp
multi_tensor/sgd.cu
util/cuda_driver.cpp
transpose/cast_transpose.cu
util/cuda_nvml.cpp
transpose/transpose.cu
util/cuda_runtime.cpp
transpose/cast_transpose_fusion.cu
util/multi_stream.cpp
transpose/transpose_fusion.cu
util/rtc.cpp
transpose/multi_cast_transpose.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
transpose/quantize_transpose_square_blockwise.cu
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
transpose/quantize_transpose_vector_blockwise.cu
comm_gemm_overlap/comm_gemm_overlap.cpp
)
transpose/swap_first_dims.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
list
(
APPEND transformer_engine_cuda_sources
activation/gelu.cu
common.cu
dropout/dropout.cu
multi_tensor/adam.cu
fused_attn/flash_attn.cu
multi_tensor/compute_scale.cu
fused_attn/context_parallel.cu
multi_tensor/l2norm.cu
fused_attn/kv_cache.cu
multi_tensor/scale.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
multi_tensor/sgd.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
transpose/cast_transpose.cu
activation/relu.cu
transpose/transpose.cu
activation/swiglu.cu
transpose/cast_transpose_fusion.cu
fused_attn/fused_attn_fp8.cu
transpose/transpose_fusion.cu
fused_attn/fused_attn.cpp
transpose/multi_cast_transpose.cu
fused_attn/utils.cu
transpose/quantize_transpose_vector_blockwise.cu
gemm/config.cpp
transpose/swap_first_dims.cu
gemm/cublaslt_gemm.cu
dropout/dropout.cu
gemm/cutlass_grouped_gemm.cu
fused_attn/flash_attn.cu
normalization/common.cpp
fused_attn/context_parallel.cu
normalization/layernorm/ln_api.cpp
fused_attn/kv_cache.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
normalization/rmsnorm/rmsnorm_api.cpp
fused_attn/fused_attn_fp8.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
fused_attn/utils.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
gemm/cublaslt_gemm.cu
permutation/permutation.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
util/cast.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
util/padding.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
util/cuda_driver.cpp
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
util/cuda_nvml.cpp
permutation/permutation.cu
util/cuda_runtime.cpp
util/padding.cu
util/multi_stream.cpp
swizzle/swizzle.cu
util/rtc.cpp
swizzle/swizzle_block_scaling.cu
swizzle/swizzle.cu
fused_softmax/scaled_masked_softmax.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_rope/fused_rope.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_router/fused_moe_aux_loss.cu
fused_rope/fused_rope.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_router/fused_moe_aux_loss.cu
fused_router/fused_topk_with_score_function.cu
fused_router/fused_score_for_moe_aux_loss.cu
recipe/current_scaling.cu
fused_router/fused_topk_with_score_function.cu
recipe/delayed_scaling.cu
recipe/current_scaling.cu
recipe/fp8_block_scaling.cu
recipe/delayed_scaling.cu
recipe/nvfp4.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/userbuffers.cu
)
recipe/nvfp4.cu
hadamard_transform/hadamard_transform.cu
list
(
APPEND transformer_engine_cuda_arch_specific_sources
hadamard_transform/hadamard_transform_cast_fusion.cu
gemm/cutlass_grouped_gemm.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
util/cast.cu
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
activation/gelu.cu
comm_gemm_overlap/userbuffers/userbuffers.cu
activation/relu.cu
comm_gemm_overlap/comm_gemm_overlap.cpp
)
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
list
(
APPEND transformer_engine_SOURCES
${
transformer_engine_cuda_arch_specific_sources
}
${
transformer_engine_cuda_sources
}
${
transformer_engine_cpp_sources
}
)
# Set compile options for CUDA sources with generic architectures
foreach
(
cuda_source IN LISTS transformer_engine_cuda_sources
)
set
(
arch_compile_options
)
foreach
(
arch IN LISTS NVTE_GENERIC_ARCHS
)
list
(
APPEND arch_compile_options
"--generate-code=arch=compute_
${
arch
}
,code=sm_
${
arch
}
"
)
endforeach
()
if
(
arch_compile_options
)
set_property
(
SOURCE
${
cuda_source
}
APPEND
PROPERTY
COMPILE_OPTIONS
${
arch_compile_options
}
)
endif
()
endforeach
()
# Set compile options for CUDA sources with specific architectures
foreach
(
cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources
)
set
(
arch_compile_options
)
foreach
(
arch IN LISTS NVTE_SPECIFIC_ARCHS
)
list
(
APPEND arch_compile_options
"--generate-code=arch=compute_
${
arch
}
,code=sm_
${
arch
}
"
)
endforeach
()
if
(
arch_compile_options
)
set_property
(
SOURCE
${
cuda_source
}
APPEND
PROPERTY
COMPILE_OPTIONS
${
arch_compile_options
}
)
endif
()
endforeach
()
if
(
NVTE_WITH_CUBLASMP
)
if
(
NVTE_WITH_CUBLASMP
)
list
(
APPEND transformer_engine_SOURCES
list
(
APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp
)
comm_gemm/comm_gemm.cpp
)
endif
()
endif
()
add_library
(
transformer_engine SHARED
${
transformer_engine_SOURCES
}
)
add_library
(
transformer_engine SHARED
${
transformer_engine_SOURCES
}
)
else
()
else
()
list
(
APPEND transformer_engine_SOURCES
list
(
APPEND transformer_engine_cpp_sources
cudnn_utils.cpp
cudnn_utils.cpp
transformer_engine.cpp
transformer_engine.cpp
common.cu
gemm/config.cpp
fused_attn/flash_attn.cu
normalization/common.cpp
fused_attn/context_parallel.cu
normalization/layernorm/ln_api.cpp
fused_attn/kv_cache.cu
normalization/rmsnorm/rmsnorm_api.cpp
multi_tensor/adam.cu
util/cuda_driver.cpp
multi_tensor/compute_scale.cu
util/cuda_nvml.cpp
multi_tensor/l2norm.cu
util/cuda_runtime.cpp
multi_tensor/scale.cu
util/multi_stream.cpp
multi_tensor/sgd.cu
util/rtc.cpp
transpose/cast_transpose.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
transpose/transpose.cu
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
transpose/cast_transpose_fusion.cu
comm_gemm_overlap/comm_gemm_overlap.cpp
)
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
list
(
APPEND transformer_engine_cuda_sources
transpose/quantize_transpose_square_blockwise.cu
common.cu
transpose/quantize_transpose_vector_blockwise.cu
multi_tensor/adam.cu
transpose/swap_first_dims.cu
multi_tensor/compute_scale.cu
activation/gelu.cu
multi_tensor/l2norm.cu
dropout/dropout.cu
multi_tensor/scale.cu
activation/relu.cu
multi_tensor/sgd.cu
activation/swiglu.cu
transpose/cast_transpose.cu
gemm/config.cpp
transpose/transpose.cu
gemm/cublaslt_gemm.cu
transpose/cast_transpose_fusion.cu
gemm/hipblas_gemm.cu
transpose/transpose_fusion.cu
normalization/common.cpp
transpose/multi_cast_transpose.cu
normalization/layernorm/ln_api.cpp
transpose/quantize_transpose_vector_blockwise.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
transpose/swap_first_dims.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
dropout/dropout.cu
normalization/rmsnorm/rmsnorm_api.cpp
fused_attn/flash_attn.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
fused_attn/context_parallel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
fused_attn/kv_cache.cu
permutation/permutation.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
util/cast.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
util/padding.cu
fused_attn/fused_attn_fp8.cu
util/cuda_driver.cpp
fused_attn/utils.cu
util/cuda_nvml.cpp
gemm/cublaslt_gemm.cu
util/cuda_runtime.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
util/multi_stream.cpp
normalization/layernorm/ln_fwd_cuda_kernel.cu
util/rtc.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
swizzle/swizzle.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
swizzle/swizzle_block_scaling.cu
permutation/permutation.cu
fused_softmax/scaled_masked_softmax.cu
util/padding.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
swizzle/swizzle.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
swizzle/swizzle_block_scaling.cu
fused_rope/fused_rope.cu
fused_softmax/scaled_masked_softmax.cu
fused_router/fused_moe_aux_loss.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_router/fused_score_for_moe_aux_loss.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_router/fused_topk_with_score_function.cu
fused_rope/fused_rope.cu
recipe/current_scaling.cu
fused_router/fused_moe_aux_loss.cu
recipe/delayed_scaling.cu
fused_router/fused_score_for_moe_aux_loss.cu
recipe/fp8_block_scaling.cu
fused_router/fused_topk_with_score_function.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
recipe/current_scaling.cu
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
recipe/delayed_scaling.cu
comm_gemm_overlap/userbuffers/userbuffers.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/comm_gemm_overlap.cpp
)
recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu
)
list
(
APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
list
(
APPEND transformer_engine_SOURCES
${
transformer_engine_cuda_arch_specific_sources
}
${
transformer_engine_cuda_sources
}
${
transformer_engine_cpp_sources
}
)
if
(
NVTE_WITH_CUBLASMP
)
if
(
NVTE_WITH_CUBLASMP
)
list
(
APPEND transformer_engine_SOURCES
list
(
APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp
)
comm_gemm/comm_gemm.cpp
)
...
@@ -316,10 +428,12 @@ if (USE_CUDA)
...
@@ -316,10 +428,12 @@ if (USE_CUDA)
CUDA::cublas
CUDA::cublas
CUDA::cudart
CUDA::cudart
CUDNN::cudnn_all
)
CUDNN::cudnn_all
)
target_include_directories
(
transformer_engine PRIVATE
target_include_directories
(
transformer_engine PRIVATE
${
CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES
}
)
${
CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES
}
)
target_include_directories
(
transformer_engine PRIVATE
${
MATHDX_INCLUDE_DIR
}
)
target_include_directories
(
transformer_engine SYSTEM PRIVATE
target_include_directories
(
transformer_engine SYSTEM PRIVATE
${
CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES
}
/cccl
)
${
CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES
}
/cccl
)
target_include_directories
(
transformer_engine PRIVATE
"
${
CUDNN_FRONTEND_INCLUDE_DIR
}
"
)
target_include_directories
(
transformer_engine PRIVATE
"
${
CUDNN_FRONTEND_INCLUDE_DIR
}
"
)
target_include_directories
(
transformer_engine PRIVATE
target_include_directories
(
transformer_engine PRIVATE
${
CUTLASS_INCLUDE_DIR
}
${
CUTLASS_INCLUDE_DIR
}
...
@@ -436,30 +550,36 @@ target_include_directories(transformer_engine PRIVATE
...
@@ -436,30 +550,36 @@ target_include_directories(transformer_engine PRIVATE
"
${
CMAKE_CURRENT_BINARY_DIR
}
/string_headers"
)
"
${
CMAKE_CURRENT_BINARY_DIR
}
/string_headers"
)
# Compiler options
# Compiler options
set_source
_files_properties
(
fused_softmax/scaled_masked_softmax.cu
set
(
nvte
_source
s_with_fast_math
)
fused_softmax/scaled_
upper_triang_
masked_softmax.cu
list
(
APPEND nvte_sources_with_fast_math
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_
aligned_causal
_masked_softmax.cu
fused_softmax/scaled_
upper_triang
_masked_softmax.cu
multi_tensor/adam
.cu
fused_softmax/scaled_aligned_causal_masked_softmax
.cu
multi_tensor/
compute_scale
.cu
multi_tensor/
adam
.cu
multi_tensor/
l2norm
.cu
multi_tensor/
compute_scale
.cu
multi_tensor/
scale
.cu
multi_tensor/
l2norm
.cu
multi_tensor/s
gd
.cu
multi_tensor/s
cale
.cu
fused_attn/flash_attn
.cu
multi_tensor/sgd
.cu
fused_attn/context_parallel
.cu
fused_attn/flash_attn
.cu
fused_attn/kv_cache
.cu
fused_attn/context_parallel
.cu
PROPERTIES
fused_attn/kv_cache.cu
)
COMPILE_OPTIONS
"--use_fast_math"
)
option
(
NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
"Compile activation kernels with --use_fast_math option"
OFF
)
option
(
NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
"Compile activation kernels with --use_fast_math option"
OFF
)
if
(
NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
)
if
(
NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
)
set_source_files_properties
(
activation/gelu.cu
list
(
APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/relu.cu
activation/relu.cu
activation/swiglu.cu
activation/swiglu.cu
util/cast.cu
util/cast.cu
)
PROPERTIES
COMPILE_OPTIONS
"--use_fast_math"
)
endif
()
endif
()
if
(
USE_CUDA
)
if
(
USE_CUDA
)
foreach
(
cuda_source IN LISTS nvte_sources_with_fast_math
)
set_property
(
SOURCE
${
cuda_source
}
APPEND
PROPERTY
COMPILE_OPTIONS
"--use_fast_math"
)
endforeach
()
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
--expt-relaxed-constexpr"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
--expt-relaxed-constexpr"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-O3"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-O3"
)
else
()
else
()
...
...
transformer_engine/common/__init__.py
View file @
0a5016b1
...
@@ -8,22 +8,18 @@ import ctypes
...
@@ -8,22 +8,18 @@ import ctypes
import
functools
import
functools
import
glob
import
glob
import
importlib
import
importlib
from
importlib.metadata
import
version
,
metadata
,
PackageNotFoundError
from
importlib.metadata
import
version
,
distribution
,
PackageNotFoundError
import
logging
import
os
import
os
from
pathlib
import
Path
from
pathlib
import
Path
import
platform
import
platform
import
subprocess
import
subprocess
import
sys
import
sys
import
sysconfig
import
sysconfig
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
_logger
=
logging
.
getLogger
(
__name__
)
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_is_
pip_
package_installed
(
package
)
->
bool
:
def
_is_package_installed
(
package
)
->
bool
:
"""Check if the given package is installed via pip."""
"""Check if the given package is installed via pip."""
# This is needed because we only want to return true
# This is needed because we only want to return true
...
@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool:
...
@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool:
# if it's importable in the current directory due to
# if it's importable in the current directory due to
# the presence of the shared library module.
# the presence of the shared library module.
try
:
try
:
metadata
(
package
)
distribution
(
package
)
except
PackageNotFoundError
:
except
PackageNotFoundError
:
return
False
return
False
return
True
return
True
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_is_package_installed_from_wheel
(
package
)
->
bool
:
"""Check if the given package is installed via PyPI."""
if
not
_is_package_installed
(
package
):
return
False
te_dist
=
distribution
(
package
)
te_wheel_file
=
""
for
file_path
in
te_dist
.
files
:
if
file_path
.
name
==
"WHEEL"
:
te_wheel_file
=
te_dist
.
locate_file
(
""
)
/
file_path
if
not
te_wheel_file
:
return
False
with
te_wheel_file
.
open
(
"r"
)
as
f
:
for
line
in
f
:
if
line
.
startswith
(
"Root-Is-Purelib:"
):
return
line
.
strip
().
split
(
":"
)[
1
].
strip
().
lower
()
==
"true"
return
False
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_find_shared_object_in_te_dir
(
te_path
:
Path
,
prefix
:
str
)
->
Optional
[
Path
]:
def
_find_shared_object_in_te_dir
(
te_path
:
Path
,
prefix
:
str
)
->
Optional
[
Path
]:
"""
"""
...
@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path:
...
@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path:
)
)
def
get_te_core_package_info
()
->
Tuple
[
bool
,
str
,
str
]:
"""
Check if Tranformer Engine core package is installed.
Returns the module name and version if found.
"""
te_core_packages
=
(
"transformer-engine-cu12"
,
"transformer-engine-cu13"
)
for
package
in
te_core_packages
:
if
_is_package_installed
(
package
):
return
True
,
package
,
version
(
package
)
return
False
,
""
,
""
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
load_framework_extension
(
framework
:
str
)
->
None
:
def
load_framework_extension
(
framework
:
str
)
->
None
:
"""
"""
...
@@ -130,39 +161,30 @@ def load_framework_extension(framework: str) -> None:
...
@@ -130,39 +161,30 @@ def load_framework_extension(framework: str) -> None:
if
framework
==
"torch"
:
if
framework
==
"torch"
:
extra_dep_name
=
"pytorch"
extra_dep_name
=
"pytorch"
# Find the TE packages. The core and framework packages can only be installed via PyPI.
# For the `transformer-engine` package, we need to check explicity.
te_core_installed
,
te_core_package_name
,
te_core_version
=
get_te_core_package_info
()
te_framework_installed
=
_is_package_installed
(
module_name
)
te_installed
=
_is_package_installed
(
"transformer_engine"
)
te_installed_via_pypi
=
_is_package_installed_from_wheel
(
"transformer_engine"
)
assert
te_installed
,
"Could not find `transformer_engine`."
# If the framework extension pip package is installed, it means that TE is installed via
# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version.
# extension are all installed via PyPI and have matching versions.
if
_is_pip_package_installed
(
module_name
):
if
te_framework_installed
:
assert
_is_pip_package_installed
(
assert
te_installed_via_pypi
,
"Could not find `transformer-engine` PyPI package."
"transformer_engine"
assert
te_core_installed
,
"Could not find TE core package `transformer-engine-cu*`."
),
"Could not find `transformer-engine`."
assert
_is_pip_package_installed
(
assert
version
(
module_name
)
==
version
(
"transformer-engine"
)
==
te_core_version
,
(
"transformer_engine_cu12"
"Transformer Engine package version mismatch. Found"
),
"Could not find `transformer-engine-cu12`."
assert
(
version
(
module_name
)
==
version
(
"transformer-engine"
)
==
version
(
"transformer-engine-cu12"
)
),
(
"TransformerEngine package version mismatch. Found"
f
"
{
module_name
}
v
{
version
(
module_name
)
}
, transformer-engine"
f
"
{
module_name
}
v
{
version
(
module_name
)
}
, transformer-engine"
f
" v
{
version
(
'transformer-engine'
)
}
, and
transformer-engine-cu12
"
f
" v
{
version
(
'transformer-engine'
)
}
, and
{
te_core_package_name
}
"
f
" v
{
version
(
'transformer-engine-cu12'
)
}
. Install transformer-engine using "
f
" v
{
te_core_version
}
. Install transformer-engine using "
f
"'pip3 install transformer-engine[
{
extra_dep_name
}
]==VERSION'"
f
"'pip3 install
--no-build-isolation
transformer-engine[
{
extra_dep_name
}
]==VERSION'"
)
)
# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
# Note: Should we error? This is a rare use case.
if
_is_pip_package_installed
(
"transformer-engine-cu12"
):
if
not
_is_pip_package_installed
(
module_name
):
_logger
.
info
(
"Could not find package %s. Install transformer-engine using "
f
"'pip3 install transformer-engine[
{
extra_dep_name
}
]==VERSION'"
,
module_name
,
)
# After all checks are completed, load the shared object file.
# After all checks are completed, load the shared object file.
spec
=
importlib
.
util
.
spec_from_file_location
(
module_name
,
_get_shared_object_file
(
framework
))
spec
=
importlib
.
util
.
spec_from_file_location
(
module_name
,
_get_shared_object_file
(
framework
))
solib
=
importlib
.
util
.
module_from_spec
(
spec
)
solib
=
importlib
.
util
.
module_from_spec
(
spec
)
...
@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None:
...
@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None:
spec
.
loader
.
exec_module
(
solib
)
spec
.
loader
.
exec_module
(
solib
)
def
sanity_checks_for_pypi_installation
()
->
None
:
"""Ensure that package is installed correctly if using PyPI."""
te_core_installed
,
te_core_package_name
,
te_core_version
=
get_te_core_package_info
()
te_installed
=
_is_package_installed
(
"transformer_engine"
)
te_installed_via_pypi
=
_is_package_installed_from_wheel
(
"transformer_engine"
)
assert
te_installed
,
"Could not find `transformer-engine`."
# If the core package is installed via PyPI.
if
te_core_installed
:
assert
te_installed_via_pypi
,
"Could not find `transformer-engine` PyPI package."
assert
version
(
"transformer-engine"
)
==
te_core_version
,
(
"Transformer Engine package version mismatch. Found "
f
"transformer-engine v
{
version
(
'transformer-engine'
)
}
"
f
"and
{
te_core_package_name
}
v
{
te_core_version
}
."
)
# Only the metapackage is found, invalid usecase.
elif
te_installed_via_pypi
:
raise
RuntimeError
(
"Found empty `transformer-engine` meta package installed. "
"Install `transformer-engine` with framework extensions via"
"'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'"
" or 'pip3 install transformer-engine[core]` for the TE core lib only. The `core_cu12`"
" or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib."
)
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_get_sys_extension
()
->
str
:
def
_get_sys_extension
()
->
str
:
"""File extension for shared objects."""
"""File extension for shared objects."""
...
@@ -339,16 +390,14 @@ def _load_core_library():
...
@@ -339,16 +390,14 @@ def _load_core_library():
if
"NVTE_PROJECT_BUILDING"
not
in
os
.
environ
or
bool
(
int
(
os
.
getenv
(
"NVTE_RELEASE_BUILD"
,
"0"
))):
if
"NVTE_PROJECT_BUILDING"
not
in
os
.
environ
or
bool
(
int
(
os
.
getenv
(
"NVTE_RELEASE_BUILD"
,
"0"
))):
try
:
sanity_checks_for_pypi_installation
()
_CUDNN_LIB_CTYPES
=
_load_cudnn
()
_CUDNN_LIB_CTYPES
=
_load_cudnn
()
_NVRTC_LIB_CTYPES
=
_load_nvrtc
()
_NVRTC_LIB_CTYPES
=
_load_nvrtc
()
_CURAND_LIB_CTYPES
=
_load_curand
()
_CURAND_LIB_CTYPES
=
_load_curand
()
_CUBLAS_LIB_CTYPES
=
_load_nvidia_cuda_library
(
"cublas"
)
_CUBLAS_LIB_CTYPES
=
_load_nvidia_cuda_library
(
"cublas"
)
_CUDART_LIB_CTYPES
=
_load_nvidia_cuda_library
(
"cuda_runtime"
)
_CUDART_LIB_CTYPES
=
_load_nvidia_cuda_library
(
"cuda_runtime"
)
# Needed to find the correct headers for NVRTC kernels.
if
not
os
.
getenv
(
"NVTE_CUDA_INCLUDE_DIR"
)
and
_nvidia_cudart_include_dir
():
os
.
environ
[
"NVTE_CUDA_INCLUDE_DIR"
]
=
_nvidia_cudart_include_dir
()
except
OSError
:
pass
_TE_LIB_CTYPES
=
_load_core_library
()
_TE_LIB_CTYPES
=
_load_core_library
()
# Needed to find the correct headers for NVRTC kernels.
if
not
os
.
getenv
(
"NVTE_CUDA_INCLUDE_DIR"
)
and
_nvidia_cudart_include_dir
():
os
.
environ
[
"NVTE_CUDA_INCLUDE_DIR"
]
=
_nvidia_cudart_include_dir
()
transformer_engine/common/fused_attn/fused_attn.cpp
View file @
0a5016b1
...
@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
int64_t
window_size_right
,
bool
return_max_logit
)
{
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
NVTE_Fused_Attn_Backend
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_No_Backend
;
NVTE_Fused_Attn_Backend
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_No_Backend
;
const
int
device_id
=
cuda
::
current_device
();
const
int
device_id
=
cuda
::
current_device
();
...
@@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_BSHD
||
qkv_format
==
NVTE_QKV_Format
::
NVTE_SBHD
)
&&
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_BSHD
||
qkv_format
==
NVTE_QKV_Format
::
NVTE_SBHD
)
&&
!
requires_64bit_ragged_offset
&&
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
)
&&
!
requires_64bit_ragged_offset
&&
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
)
&&
// 9.10.0: known bugs with SDPA FP8
// 9.10.0: known bugs with SDPA FP8
(
cudnn_runtime_version
!=
91000
))
{
(
cudnn_runtime_version
!=
91000
)
&&
!
return_max_logit
)
{
if
(
cudnn_runtime_version
>=
8900
)
{
if
(
cudnn_runtime_version
>=
8900
)
{
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_FP8
;
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_FP8
;
}
else
{
}
else
{
...
@@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(
qkv_layout
==
NVTE_QKV_Layout
::
NVTE_BSHD_BSHD_BSHD
))
&&
(
qkv_layout
==
NVTE_QKV_Layout
::
NVTE_BSHD_BSHD_BSHD
))
&&
((
window_size_left
==
-
1
)
&&
(
window_size_right
==
-
1
||
window_size_right
==
0
))
&&
((
window_size_left
==
-
1
)
&&
(
window_size_right
==
-
1
||
window_size_right
==
0
))
&&
!
requires_64bit_ragged_offset
&&
!
requires_64bit_ragged_offset
&&
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
))
{
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
)
&&
!
return_max_logit
)
{
flag_m512
=
true
;
flag_m512
=
true
;
}
}
if
(
if
(
...
@@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
...
@@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
size_t
max_seqlen
,
bool
is_training
,
float
attn_scale
,
size_t
max_seqlen
,
bool
is_training
,
bool
return_max_logit
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
int64_t
window_size_right
,
NVTETensor
workspace
,
...
@@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
...
@@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
is_training
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
is_training
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h
,
h
,
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
);
h
,
h
,
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
,
return_max_logit
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
...
@@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
}
else
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
)
{
}
else
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
)
{
#if (CUDNN_VERSION >= 8900)
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
b
,
h
,
max_seqlen
,
d
,
t
,
is_training
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
b
,
h
,
max_seqlen
,
d
,
t
,
is_training
,
return_max_logit
,
attn_scale
,
dropout
,
qkv_layout
,
attn_mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
input_QKV
,
input_Bias
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
input_QKV
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens
,
input_cu_seqlens_padded
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens
,
input_rng_state
,
wkspace
,
stream
,
handle
);
input_cu_seqlens_padded
,
input_rng_state
,
wkspace
,
stream
,
handle
);
#else
#else
NVTE_ERROR
(
NVTE_ERROR
(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
...
@@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
...
@@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
true
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h
,
h
,
true
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h
,
h
,
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
);
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
,
false
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked(
...
@@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked(
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
NVTETensor
workspace
,
cudaStream_t
stream
)
{
...
@@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked(
...
@@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
);
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
,
return_max_logit
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked(
...
@@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked(
#if (CUDNN_VERSION >= 8903)
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked
(
fused_attn_arbitrary_seqlen_fwd_kvpacked
(
b
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
t_q
,
t_kv
,
num_pages_k
,
num_pages_v
,
b
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
t_q
,
t_kv
,
num_pages_k
,
num_pages_v
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
attn_scale
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
return_max_logit
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_right
,
input_Q
,
input_KV
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
window_size_left
,
window_size_right
,
input_Q
,
input_KV
,
input_Bias
,
input_SoftmaxOffset
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_q_padded
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_
kv
_padded
,
input_
page_table_k
,
input_page_table_
v
,
input_rng_state
,
input_cu_seqlens_
q
_padded
,
input_
cu_seqlens_kv_padded
,
input_page_table_
k
,
wkspace
,
stream
,
handle
);
input_page_table_v
,
input_rng_state
,
wkspace
,
stream
,
handle
);
#else
#else
NVTE_ERROR
(
NVTE_ERROR
(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
...
@@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked(
...
@@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
true
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
true
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
);
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
,
false
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -832,18 +833,16 @@ void nvte_fused_attn_bwd_kvpacked(
...
@@ -832,18 +833,16 @@ void nvte_fused_attn_bwd_kvpacked(
}
}
}
}
// NVTE fused attention FWD with separate Q, K and V
// NVTE fused attention FWD with separate Q, K and V
void
nvte_fused_attn_fwd
(
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
void
nvte_fused_attn_fwd
(
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
const
NVTETensor
Bias
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_flash_attn_fwd
);
NVTE_API_CALL
(
nvte_flash_attn_fwd
);
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
const
Tensor
*
input_cu_seqlens_q
=
convertNVTETensorCheck
(
cu_seqlens_q
);
const
Tensor
*
input_cu_seqlens_q
=
convertNVTETensorCheck
(
cu_seqlens_q
);
...
@@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
...
@@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
);
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
,
return_max_logit
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
...
@@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#if (CUDNN_VERSION >= 8900)
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd
(
fused_attn_arbitrary_seqlen_fwd
(
b
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
t_q
,
t_kv
,
num_pages_k
,
num_pages_v
,
b
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
t_q
,
t_kv
,
num_pages_k
,
num_pages_v
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
attn_scale
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
return_max_logit
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_right
,
input_Q
,
input_K
,
input_V
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
window_size_left
,
window_size_right
,
input_Q
,
input_K
,
input_V
,
input_Bias
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_q_padded
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_
kv
_padded
,
input_
page_table_k
,
input_page_table_
v
,
input_rng_state
,
input_cu_seqlens_
q
_padded
,
input_
cu_seqlens_kv_padded
,
input_page_table_
k
,
wkspace
,
stream
,
handle
);
input_page_table_v
,
input_rng_state
,
wkspace
,
stream
,
handle
);
#else
#else
NVTE_ERROR
(
NVTE_ERROR
(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
...
@@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
...
@@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
true
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
true
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
);
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
,
false
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
...
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
View file @
0a5016b1
This diff is collapsed.
Click to expand it.
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
View file @
0a5016b1
...
@@ -20,12 +20,13 @@ namespace transformer_engine {
...
@@ -20,12 +20,13 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
#if (CUDNN_VERSION >= 8900)
void
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
void
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
bool
is_training
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
p_dropout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
const
Tensor
*
input_QKV
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
const
Tensor
*
input_Bias
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
const
Tensor
*
input_QKV
,
const
Tensor
*
input_Bias
,
const
Tensor
*
input_SoftmaxOffset
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens
,
const
Tensor
*
cu_seqlens_padded
,
Tensor
*
output_O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
const
Tensor
*
cu_seqlens_padded
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
void
fused_attn_arbitrary_seqlen_bwd_qkvpacked
(
void
fused_attn_arbitrary_seqlen_bwd_qkvpacked
(
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
...
@@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim
,
size_t
num_tokens_q
,
size_t
num_tokens_kv
,
size_t
max_seqlen_kv
,
size_t
head_dim
,
size_t
num_tokens_q
,
size_t
num_tokens_kv
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
page_size_v
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
page_size_v
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
float
attn_scale
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
bool
return_max_logit
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
const
Tensor
*
input_Q
,
const
Tensor
*
input_KV
,
const
Tensor
*
input_Bias
,
int64_t
window_size_right
,
const
Tensor
*
input_Q
,
const
Tensor
*
input_KV
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
input_Bias
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
cu_seqlens_q_padded
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
const
Tensor
*
cu_seqlens_q_padded
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
void
fused_attn_arbitrary_seqlen_bwd_kvpacked
(
void
fused_attn_arbitrary_seqlen_bwd_kvpacked
(
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
...
@@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
size_t
num_tokens_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
size_t
num_tokens_q
,
size_t
num_tokens_kv
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
num_tokens_kv
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
page_size_v
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
size_t
page_size_v
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
bool
return_max_logit
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_
right
,
const
Tensor
*
input_Q
,
const
Tensor
*
input_K
,
const
Tensor
*
input_
V
,
int64_t
window_size_
left
,
int64_t
window_size_right
,
const
Tensor
*
input_
Q
,
const
Tensor
*
input_
Bias
,
const
Tensor
*
input_
SoftmaxOffset
,
Tensor
*
out
put_
O
,
const
Tensor
*
input_
K
,
const
Tensor
*
input_
V
,
const
Tensor
*
in
put_
Bias
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens_q
_padded
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
cu_seqlens_q
_padded
,
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
const
Tensor
*
rng_state
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
void
fused_attn_arbitrary_seqlen_bwd
(
void
fused_attn_arbitrary_seqlen_bwd
(
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
...
...
transformer_engine/common/fused_attn/fused_attn_fp8.cu
View file @
0a5016b1
...
@@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1(
...
@@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1(
qkv_tensor_type
,
qkv_tensor_type
,
o_tensor_type
,
o_tensor_type
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
};
cudnn_frontend
::
DataType_t
::
NOT_SET
,
false
};
namespace
fe
=
cudnn_frontend
;
namespace
fe
=
cudnn_frontend
;
using
graph_and_tensors
=
using
graph_and_tensors
=
...
@@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1(
...
@@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1(
qkv_tensor_type
,
qkv_tensor_type
,
o_tensor_type
,
o_tensor_type
,
do_tensor_type
,
do_tensor_type
,
dqkv_tensor_type
};
dqkv_tensor_type
,
false
};
namespace
fe
=
cudnn_frontend
;
namespace
fe
=
cudnn_frontend
;
using
graph_and_tensors
=
using
graph_and_tensors
=
...
...
transformer_engine/common/fused_attn/utils.h
View file @
0a5016b1
...
@@ -115,20 +115,21 @@ struct FADescriptor_v1 {
...
@@ -115,20 +115,21 @@ struct FADescriptor_v1 {
cudnn_frontend
::
DataType_t
o_tensor_type
;
cudnn_frontend
::
DataType_t
o_tensor_type
;
cudnn_frontend
::
DataType_t
do_tensor_type
;
cudnn_frontend
::
DataType_t
do_tensor_type
;
cudnn_frontend
::
DataType_t
dqkv_tensor_type
;
cudnn_frontend
::
DataType_t
dqkv_tensor_type
;
bool
generate_max_sum_exp
;
bool
operator
<
(
const
FADescriptor_v1
&
rhs
)
const
{
bool
operator
<
(
const
FADescriptor_v1
&
rhs
)
const
{
return
std
::
tie
(
b
,
h
,
hg
,
s_q
,
s_kv
,
d_qk
,
d_v
,
num_pages_k
,
num_pages_v
,
page_size_k
,
return
std
::
tie
(
b
,
h
,
hg
,
s_q
,
s_kv
,
d_qk
,
d_v
,
num_pages_k
,
num_pages_v
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
bias_b
,
bias_h
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
bias_b
,
bias_h
,
attnScale
,
isTraining
,
dropoutProbability
,
layout
,
mask_type
,
softmax_type
,
attnScale
,
isTraining
,
dropoutProbability
,
layout
,
mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
deterministic
,
bias_type
,
qkv_tensor_type
,
window_size_left
,
window_size_right
,
deterministic
,
bias_type
,
qkv_tensor_type
,
o_tensor_type
,
do_tensor_type
,
dqkv_tensor_type
)
<
o_tensor_type
,
do_tensor_type
,
dqkv_tensor_type
,
generate_max_sum_exp
)
<
std
::
tie
(
rhs
.
b
,
rhs
.
h
,
rhs
.
hg
,
rhs
.
s_q
,
rhs
.
s_kv
,
rhs
.
d_qk
,
rhs
.
d_v
,
rhs
.
num_pages_k
,
std
::
tie
(
rhs
.
b
,
rhs
.
h
,
rhs
.
hg
,
rhs
.
s_q
,
rhs
.
s_kv
,
rhs
.
d_qk
,
rhs
.
d_v
,
rhs
.
num_pages_k
,
rhs
.
num_pages_v
,
rhs
.
page_size_k
,
rhs
.
page_size_v
,
rhs
.
max_pages_per_seq_k
,
rhs
.
num_pages_v
,
rhs
.
page_size_k
,
rhs
.
page_size_v
,
rhs
.
max_pages_per_seq_k
,
rhs
.
max_pages_per_seq_v
,
rhs
.
bias_b
,
rhs
.
bias_h
,
rhs
.
attnScale
,
rhs
.
isTraining
,
rhs
.
max_pages_per_seq_v
,
rhs
.
bias_b
,
rhs
.
bias_h
,
rhs
.
attnScale
,
rhs
.
isTraining
,
rhs
.
dropoutProbability
,
rhs
.
layout
,
rhs
.
mask_type
,
rhs
.
softmax_type
,
rhs
.
dropoutProbability
,
rhs
.
layout
,
rhs
.
mask_type
,
rhs
.
softmax_type
,
rhs
.
window_size_left
,
rhs
.
window_size_right
,
rhs
.
deterministic
,
rhs
.
bias_type
,
rhs
.
window_size_left
,
rhs
.
window_size_right
,
rhs
.
deterministic
,
rhs
.
bias_type
,
rhs
.
qkv_tensor_type
,
rhs
.
o_tensor_type
,
rhs
.
do_tensor_type
,
rhs
.
qkv_tensor_type
,
rhs
.
o_tensor_type
,
rhs
.
do_tensor_type
,
rhs
.
dqkv_tensor_type
);
rhs
.
dqkv_tensor_type
,
rhs
.
generate_max_sum_exp
);
}
}
};
};
...
...
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
View file @
0a5016b1
...
@@ -97,22 +97,23 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
...
@@ -97,22 +97,23 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
StochasticNumericConverterBase
(
cutlass
::
Array
<
float
,
8
>
const
&
input
,
cutlass
::
Array
<
uint32_t
,
2
>
const
&
rbits
)
{
StochasticNumericConverterBase
(
cutlass
::
Array
<
float
,
8
>
const
&
input
,
cutlass
::
Array
<
uint32_t
,
2
>
const
&
rbits
)
{
using
result_type
=
cutlass
::
Array
<
cutlass
::
float_e2m1_t
,
8
>
;
using
result_type
=
cutlass
::
Array
<
cutlass
::
float_e2m1_t
,
8
>
;
result_type
output
;
result_type
output
;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_rs
=
ARCH_HAS_STOCHASTIC_ROUNDING
;
auto
output_ptr
=
reinterpret_cast
<
uint16_t
*>
(
&
output
);
if
constexpr
(
has_rs
)
{
asm
volatile
(
\
auto
output_ptr
=
reinterpret_cast
<
uint16_t
*>
(
&
output
);
"{
\n
"
\
asm
volatile
(
\
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;
\n
"
\
"{
\n
"
\
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;
\n
"
\
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;
\n
"
\
"}"
\
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;
\n
"
\
:
"=h"
(
output_ptr
[
0
]),
"}"
\
:
"=h"
(
output_ptr
[
0
]),
"=h"
(
output_ptr
[
1
])
"=h"
(
output_ptr
[
1
])
:
"f"
(
input
[
0
]),
"f"
(
input
[
1
]),
"f"
(
input
[
2
]),
"f"
(
input
[
3
]),
:
"f"
(
input
[
0
]),
"f"
(
input
[
1
]),
"f"
(
input
[
2
]),
"f"
(
input
[
3
]),
"f"
(
input
[
4
]),
"f"
(
input
[
5
]),
"f"
(
input
[
6
]),
"f"
(
input
[
7
]),
"f"
(
input
[
4
]),
"f"
(
input
[
5
]),
"f"
(
input
[
6
]),
"f"
(
input
[
7
]),
"r"
(
rbits
[
0
]),
"r"
(
rbits
[
1
]));
"r"
(
rbits
[
0
]),
"r"
(
rbits
[
1
]));
#
else
}
else
{
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
return
output
;
return
output
;
}
}
...
...
transformer_engine/common/include/transformer_engine/fused_attn.h
View file @
0a5016b1
...
@@ -190,29 +190,30 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
...
@@ -190,29 +190,30 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get fused attention backend based on input parameters.
/*! \brief Get fused attention backend based on input parameters.
*
*
* \param[in] is_training Whether the model is in training mode.
* \param[in] is_training Whether the model is in training mode.
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type.
* \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] num_gqa_groups The number of heads in K, V.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim_qk The head dimension of Q, K.
* \param[in] head_dim_qk The head dimension of Q, K.
* \param[in] head_dim_v The head dimension of V.
* \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
*/
*/
NVTE_Fused_Attn_Backend
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
nvte_get_fused_attn_backend
(
bool
is_training
,
NVTEDType
q_dtype
,
NVTEDType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
bool
is_training
,
NVTEDType
q_dtype
,
NVTEDType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
);
int64_t
window_size_right
,
bool
return_max_logit
);
/*! \brief Compute dot product attention with packed QKV input.
/*! \brief Compute dot product attention with packed QKV input.
*
*
...
@@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] max_seqlen Max sequence length used for computing,
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] qkv_layout QKV tensor's layout.
...
@@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] workspace Workspace tensor.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
*/
void
nvte_fused_attn_fwd_qkvpacked
(
void
nvte_fused_attn_fwd_qkvpacked
(
const
NVTETensor
QKV
,
const
NVTETensor
Bias
,
const
NVTETensor
QKV
,
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
size_t
max_seqlen
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
size_t
max_seqlen
,
bool
is_training
,
bool
return_max_logit
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
*
...
@@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
...
@@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] qkv_layout QKV tensor's layout.
...
@@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked(
...
@@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked(
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
NVTETensor
workspace
,
cudaStream_t
stream
);
...
@@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked(
...
@@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] qkv_layout QKV tensors' layout.
...
@@ -531,18 +538,16 @@ void nvte_fused_attn_bwd_kvpacked(
...
@@ -531,18 +538,16 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] workspace Workspace tensor.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
*/
void
nvte_fused_attn_fwd
(
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
void
nvte_fused_attn_fwd
(
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
const
NVTETensor
Bias
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
*
...
...
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
View file @
0a5016b1
...
@@ -264,48 +264,50 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s
...
@@ -264,48 +264,50 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s
__device__
__forceinline__
__nv_fp4x4_e2m1
cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
__device__
__forceinline__
__nv_fp4x4_e2m1
cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
const
float2
in01
,
const
float2
in23
,
const
uint32_t
rbits
)
{
const
float2
in01
,
const
float2
in23
,
const
uint32_t
rbits
)
{
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_rs
=
ARCH_HAS_STOCHASTIC_ROUNDING
;
uint16_t
out_4x
;
if
constexpr
(
has_rs
)
{
asm
volatile
(
uint16_t
out_4x
;
"{
\n
"
asm
volatile
(
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5;
\n\t
"
"{
\n
"
"}"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5;
\n\t
"
:
"=h"
(
out_4x
)
"}"
:
"f"
(
in01
.
y
),
"f"
(
in01
.
x
),
"f"
(
in23
.
y
),
"f"
(
in23
.
x
),
"r"
(
rbits
));
:
"=h"
(
out_4x
)
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
out_4x
);
:
"f"
(
in01
.
y
),
"f"
(
in01
.
x
),
"f"
(
in23
.
y
),
"f"
(
in23
.
x
),
"r"
(
rbits
));
#else
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
out_4x
);
NVTE_DEVICE_ERROR
(
}
else
{
"FP4 cvt PTX instructions are architecture-specific. "
NVTE_DEVICE_ERROR
(
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"FP4 cvt.rs PTX instructions are architecture-specific. "
uint16_t
dummy
=
0
;
"Try recompiling with sm_XXXa instead of sm_XXX."
);
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
dummy
);
uint16_t
dummy
=
0
;
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
dummy
);
}
}
}
__device__
__forceinline__
__nv_fp4x4_e2m1
cvt_fp32_to_fp4_4x_with_rn
(
const
float2
in01
,
__device__
__forceinline__
__nv_fp4x4_e2m1
cvt_fp32_to_fp4_4x_with_rn
(
const
float2
in01
,
const
float2
in23
,
const
float2
in23
,
const
uint32_t
rbits
)
{
const
uint32_t
rbits
)
{
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_fp4
=
ARCH_BLACKWELL_FAMILY
;
// NOTE: rbits unused for rn.
if
constexpr
(
has_fp4
)
{
uint32_t
out_4x
;
// Only need 16 bit. Using 32 bit container for packing.
// NOTE: rbits unused for rn.
asm
volatile
(
uint32_t
out_4x
;
// Only need 16 bit. Using 32 bit container for packing.
"{
\n
"
asm
volatile
(
".reg.b8 f0;
\n\t
"
"{
\n
"
".reg.b8 f1;
\n\t
"
".reg.b8 f0;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;
\n\t
"
".reg.b8 f1;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;
\n\t
"
"mov.b32 %0, {f0, f1, f0, f1};
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;
\n\t
"
"}"
"mov.b32 %0, {f0, f1, f0, f1};
\n\t
"
:
"=r"
(
out_4x
)
"}"
:
"f"
(
in01
.
y
),
"f"
(
in01
.
x
),
"f"
(
in23
.
y
),
"f"
(
in23
.
x
));
:
"=r"
(
out_4x
)
return
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
out_4x
)[
0
];
:
"f"
(
in01
.
y
),
"f"
(
in01
.
x
),
"f"
(
in23
.
y
),
"f"
(
in23
.
x
));
#else
return
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
out_4x
)[
0
];
NVTE_DEVICE_ERROR
(
}
else
{
"FP4 cvt PTX instructions are architecture-specific. "
NVTE_DEVICE_ERROR
(
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"FP4 cvt PTX instructions are architecture-specific. "
uint16_t
dummy
=
0
;
"Try recompiling with sm_XXXa instead of sm_XXX."
);
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
dummy
);
uint16_t
dummy
=
0
;
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
dummy
);
}
}
}
template
<
bool
kApplyStochasticRounding
>
template
<
bool
kApplyStochasticRounding
>
...
...
transformer_engine/common/util/nvfp4_transpose.cuh
View file @
0a5016b1
...
@@ -15,10 +15,9 @@
...
@@ -15,10 +15,9 @@
#include <cudaTypedefs.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#if
CUDA_VERSION > 12080
#if
FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#include <cuda_fp4.h>
#endif // CUDA_VERSION > 12080
#endif // FP4_TYPE_SUPPORTED
#include <cfloat>
#include <cfloat>
#include "../common.h"
#include "../common.h"
...
@@ -30,7 +29,7 @@
...
@@ -30,7 +29,7 @@
namespace
transformer_engine
{
namespace
transformer_engine
{
#if
CUDA_VERSION > 12080
#if
FP4_TYPE_SUPPORTED
namespace
nvfp4_transpose
{
namespace
nvfp4_transpose
{
using
RNG
=
decltype
(
curanddx
::
Generator
<
curanddx
::
philox4_32
>
()
+
curanddx
::
PhiloxRounds
<
10
>
()
+
using
RNG
=
decltype
(
curanddx
::
Generator
<
curanddx
::
philox4_32
>
()
+
curanddx
::
PhiloxRounds
<
10
>
()
+
...
@@ -152,89 +151,89 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int
...
@@ -152,89 +151,89 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int
return
rbits
;
return
rbits
;
}
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding
(
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding
(
const
uint64_t
in_4x
,
const
float2
scale
,
const
uint32_t
rbits
)
{
const
uint64_t
in_4x
,
const
float2
scale
,
const
uint32_t
rbits
)
{
uint16_t
out_4x
=
0
;
uint16_t
out_4x
=
0
;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_rs
=
ARCH_HAS_STOCHASTIC_ROUNDING
;
asm
volatile
(
if
constexpr
(
has_rs
)
{
"{
\n
"
asm
volatile
(
".reg.b64 v01;
\n\t
"
"{
\n
"
".reg.b64 v23;
\n\t
"
".reg.b64 v01;
\n\t
"
".reg.b16 v0_bf16;
\n\t
"
".reg.b64 v23;
\n\t
"
".reg.b16 v1_bf16;
\n\t
"
".reg.b16 v0_bf16;
\n\t
"
".reg.b16 v2_bf16;
\n\t
"
".reg.b16 v1_bf16;
\n\t
"
".reg.b16 v3_bf16;
\n\t
"
".reg.b16 v2_bf16;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b16 v3_bf16;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v3;
\n\t
"
".reg.b32 v2;
\n\t
"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1;
\n\t
"
".reg.b32 v3;
\n\t
"
"cvt.f32.bf16 v0, v0_bf16;
\n\t
"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1;
\n\t
"
"cvt.f32.bf16 v1, v1_bf16;
\n\t
"
"cvt.f32.bf16 v0, v0_bf16;
\n\t
"
"cvt.f32.bf16 v2, v2_bf16;
\n\t
"
"cvt.f32.bf16 v1, v1_bf16;
\n\t
"
"cvt.f32.bf16 v3, v3_bf16;
\n\t
"
"cvt.f32.bf16 v2, v2_bf16;
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"cvt.f32.bf16 v3, v3_bf16;
\n\t
"
"mov.b64 v23, {v2, v3};
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mul.f32x2 v01, v01, %2;
\n\t
"
// mind the shuffled elements order
"mov.b64 v23, {v2, v3};
\n\t
"
"mul.f32x2 v23, v23, %2;
\n\t
"
// mind the shuffled elements order
"mul.f32x2 v01, v01, %2;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v1, v0}, v01;
\n\t
"
"mul.f32x2 v23, v23, %2;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v3, v2}, v23;
\n\t
"
"mov.b64 {v1, v0}, v01;
\n\t
"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v3, v2}, v23;
\n\t
"
"}"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3;
\n\t
"
// mind the shuffled elements order
:
"=h"
(
out_4x
)
"}"
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
:
"=h"
(
out_4x
)
#else
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
NVTE_DEVICE_ERROR
(
}
else
{
"FP4 cvt PTX instructions are architecture-specific. "
NVTE_DEVICE_ERROR
(
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
"Try recompiling with sm_XXXa instead of sm_XXX."
);
}
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
}
}
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_rn
(
const
uint64_t
in_4x
,
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_rn
(
const
uint64_t
in_4x
,
const
float2
scale
,
const
float2
scale
,
const
uint32_t
rbits
)
{
const
uint32_t
rbits
)
{
// NOTE: rbits unused for rn.
constexpr
bool
is_blackwell
=
ARCH_BLACKWELL_FAMILY
;
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
if
constexpr
(
is_blackwell
)
{
asm
volatile
(
// NOTE: rbits unused for rn.
"{
\n
"
asm
volatile
(
".reg.b64 v01;
\n\t
"
"{
\n
"
".reg.b64 v23;
\n\t
"
".reg.b64 v01;
\n\t
"
".reg.b16 v0_bf16;
\n\t
"
".reg.b64 v23;
\n\t
"
".reg.b16 v1_bf16;
\n\t
"
".reg.b16 v0_bf16;
\n\t
"
".reg.b16 v2_bf16;
\n\t
"
".reg.b16 v1_bf16;
\n\t
"
".reg.b16 v3_bf16;
\n\t
"
".reg.b16 v2_bf16;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b16 v3_bf16;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v3;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b8 f0;
\n\t
"
".reg.b32 v3;
\n\t
"
".reg.b8 f1;
\n\t
"
".reg.b8 f0;
\n\t
"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1;
\n\t
"
".reg.b8 f1;
\n\t
"
"cvt.f32.bf16 v0, v0_bf16;
\n\t
"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1;
\n\t
"
"cvt.f32.bf16 v1, v1_bf16;
\n\t
"
"cvt.f32.bf16 v0, v0_bf16;
\n\t
"
"cvt.f32.bf16 v2, v2_bf16;
\n\t
"
"cvt.f32.bf16 v1, v1_bf16;
\n\t
"
"cvt.f32.bf16 v3, v3_bf16;
\n\t
"
"cvt.f32.bf16 v2, v2_bf16;
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"cvt.f32.bf16 v3, v3_bf16;
\n\t
"
"mov.b64 v23, {v2, v3};
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mul.f32x2 v01, v01, %2;
\n\t
"
// mind the shuffled elements order
"mov.b64 v23, {v2, v3};
\n\t
"
"mul.f32x2 v23, v23, %2;
\n\t
"
// mind the shuffled elements order
"mul.f32x2 v01, v01, %2;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v1, v0}, v01;
\n\t
"
"mul.f32x2 v23, v23, %2;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v3, v2}, v23;
\n\t
"
"mov.b64 {v1, v0}, v01;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;
\n\t
"
"mov.b64 {v3, v2}, v23;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;
\n\t
"
"mov.b32 %0, {f0, f1, f0, f1};
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;
\n\t
"
"}"
"mov.b32 %0, {f0, f1, f0, f1};
\n\t
"
:
"=r"
(
out_4x
)
"}"
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
:
"=r"
(
out_4x
)
#else
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
NVTE_DEVICE_ERROR
(
}
else
{
"FP4 cvt PTX instructions are architecture-specific. "
NVTE_DEVICE_ERROR
(
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
"Try recompiling with sm_XXXa instead of sm_XXX."
);
}
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
}
}
...
@@ -252,34 +251,35 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x
...
@@ -252,34 +251,35 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x
__device__
__forceinline__
fp4e2m1x4
mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
__device__
__forceinline__
fp4e2m1x4
mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
const
float2
in01
,
const
float2
in23
,
const
float2
scale
,
const
uint32_t
rbits
)
{
const
float2
in01
,
const
float2
in23
,
const
float2
scale
,
const
uint32_t
rbits
)
{
uint16_t
out_4x
=
0
;
uint16_t
out_4x
=
0
;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_rs
=
ARCH_HAS_STOCHASTIC_ROUNDING
;
asm
volatile
(
if
constexpr
(
has_rs
)
{
"{
\n
"
asm
volatile
(
".reg.b64 v01;
\n\t
"
"{
\n
"
".reg.b64 v23;
\n\t
"
".reg.b64 v01;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b64 v23;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v3;
\n\t
"
".reg.b32 v2;
\n\t
"
"mov.b64 {v0, v1} , %1;
\n\t
"
".reg.b32 v3;
\n\t
"
"mov.b64 {v2, v3} , %2;
\n\t
"
"mov.b64 {v0, v1} , %1;
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mov.b64 {v2, v3} , %2;
\n\t
"
"mov.b64 v23, {v2, v3};
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mul.f32x2 v01, v01, %3;
\n\t
"
// mind the shuffled elements order
"mov.b64 v23, {v2, v3};
\n\t
"
"mul.f32x2 v23, v23, %3;
\n\t
"
// mind the shuffled elements order
"mul.f32x2 v01, v01, %3;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v1, v0}, v01;
\n\t
"
"mul.f32x2 v23, v23, %3;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v3, v2}, v23;
\n\t
"
"mov.b64 {v1, v0}, v01;
\n\t
"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v3, v2}, v23;
\n\t
"
"}"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4;
\n\t
"
// mind the shuffled elements order
:
"=h"
(
out_4x
)
"}"
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
:
"=h"
(
out_4x
)
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
#else
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
NVTE_DEVICE_ERROR
(
}
else
{
"FP4 cvt PTX instructions are architecture-specific. "
NVTE_DEVICE_ERROR
(
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
"Try recompiling with sm_XXXa instead of sm_XXX."
);
}
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
}
}
...
@@ -287,40 +287,41 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
...
@@ -287,40 +287,41 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
const
float2
in23
,
const
float2
in23
,
const
float2
scale
,
const
float2
scale
,
const
uint32_t
rbits
)
{
const
uint32_t
rbits
)
{
// NOTE: rbits unused for rn.
constexpr
bool
is_blackwell
=
ARCH_BLACKWELL_FAMILY
;
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
if
constexpr
(
is_blackwell
)
{
asm
volatile
(
// NOTE: rbits unused for rn.
"{
\n
"
asm
volatile
(
".reg.b64 v01;
\n\t
"
"{
\n
"
".reg.b64 v23;
\n\t
"
".reg.b64 v01;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b64 v23;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v0;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b32 v1;
\n\t
"
".reg.b32 v3;
\n\t
"
".reg.b32 v2;
\n\t
"
".reg.b8 f0;
\n\t
"
".reg.b32 v3;
\n\t
"
".reg.b8 f1;
\n\t
"
".reg.b8 f0;
\n\t
"
"mov.b64 {v0, v1} , %1;
\n\t
"
".reg.b8 f1;
\n\t
"
"mov.b64 {v2, v3} , %2;
\n\t
"
"mov.b64 {v0, v1} , %1;
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mov.b64 {v2, v3} , %2;
\n\t
"
"mov.b64 v23, {v2, v3};
\n\t
"
"mov.b64 v01, {v0, v1};
\n\t
"
"mul.f32x2 v01, v01, %3;
\n\t
"
// mind the shuffled elements order
"mov.b64 v23, {v2, v3};
\n\t
"
"mul.f32x2 v23, v23, %3;
\n\t
"
// mind the shuffled elements order
"mul.f32x2 v01, v01, %3;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v1, v0}, v01;
\n\t
"
"mul.f32x2 v23, v23, %3;
\n\t
"
// mind the shuffled elements order
"mov.b64 {v3, v2}, v23;
\n\t
"
"mov.b64 {v1, v0}, v01;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;
\n\t
"
"mov.b64 {v3, v2}, v23;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;
\n\t
"
"mov.b32 %0, {f0, f1, f0, f1};
\n\t
"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;
\n\t
"
"}"
"mov.b32 %0, {f0, f1, f0, f1};
\n\t
"
:
"=r"
(
out_4x
)
"}"
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
:
"=r"
(
out_4x
)
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
#else
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
NVTE_DEVICE_ERROR
(
}
else
{
"FP4 cvt PTX instructions are architecture-specific. "
NVTE_DEVICE_ERROR
(
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"FP4 cvt PTX instructions are architecture-specific. "
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
"Try recompiling with sm_XXXa instead of sm_XXX."
);
}
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
}
}
...
@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
...
@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
}
}
}
}
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
typename
IType
,
bool
USE_STOCHASTIC_ROUNDING
,
bool
RETURN_TRANSPOSE
>
typename
IType
,
bool
USE_STOCHASTIC_ROUNDING
,
bool
RETURN_TRANSPOSE
>
__global__
void
__launch_bounds__
(
THREADS_NUM
)
__global__
void
__launch_bounds__
(
THREADS_NUM
)
...
@@ -1380,18 +1379,13 @@ __global__ void __launch_bounds__(THREADS_NUM)
...
@@ -1380,18 +1379,13 @@ __global__ void __launch_bounds__(THREADS_NUM)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
}
// namespace nvfp4_transpose
}
// namespace nvfp4_transpose
#endif // CUDA_VERSION > 12080
#endif // FP4_TYPE_SUPPORTED
// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
bool
use_2d_quantization
>
bool
use_2d_quantization
>
void
nvfp4_quantize_transpose
(
const
Tensor
&
input
,
const
Tensor
*
noop
,
Tensor
*
output
,
void
nvfp4_quantize_transpose
(
const
Tensor
&
input
,
const
Tensor
*
noop
,
Tensor
*
output
,
const
QuantizationConfig
*
quant_config
,
cudaStream_t
stream
)
{
const
QuantizationConfig
*
quant_config
,
cudaStream_t
stream
)
{
#if
CUDA_VERSION > 12080
#if
FP4_TYPE_SUPPORTED
bool
use_stochastic_rounding
=
quant_config
?
quant_config
->
stochastic_rounding
:
false
;
bool
use_stochastic_rounding
=
quant_config
?
quant_config
->
stochastic_rounding
:
false
;
// If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
// If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
...
@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o
...
@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o
}););
}););
#else
#else
NVTE_ERROR
(
"FP4 support requires CUDA 12.8+, but compile-time CUDA version is "
,
CUDA_VERSION
);
NVTE_ERROR
(
"FP4 support requires CUDA 12.8+, but compile-time CUDA version is "
,
CUDA_VERSION
);
#endif //
CUDA_VERSION > 12080
#endif //
FP4_TYPE_SUPPORTED
}
}
}
// namespace transformer_engine
}
// namespace transformer_engine
...
...
transformer_engine/common/util/ptx.cuh
View file @
0a5016b1
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment