Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0a5016b1
Commit
0a5016b1
authored
Dec 03, 2025
by
wenjh
Browse files
Merge nv release_v2.9
Signed-off-by:
wenjh
<
wenjh@sugon.com
>
parents
063ef88d
70f53666
Changes
61
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1391 additions
and
767 deletions
+1391
-767
tests/jax/test_distributed_softmax.py
tests/jax/test_distributed_softmax.py
+6
-4
tests/jax/test_fused_attn.py
tests/jax/test_fused_attn.py
+3
-3
tests/jax/test_helper.py
tests/jax/test_helper.py
+81
-1
tests/pytorch/attention/run_attention_with_cp.py
tests/pytorch/attention/run_attention_with_cp.py
+11
-4
tests/pytorch/attention/test_attention.py
tests/pytorch/attention/test_attention.py
+70
-20
tests/pytorch/attention/test_attention_with_cp.py
tests/pytorch/attention/test_attention_with_cp.py
+3
-3
tests/pytorch/test_numerics.py
tests/pytorch/test_numerics.py
+1
-29
tests/pytorch/utils.py
tests/pytorch/utils.py
+3
-0
transformer_engine/common/CMakeLists.txt
transformer_engine/common/CMakeLists.txt
+282
-162
transformer_engine/common/__init__.py
transformer_engine/common/__init__.py
+96
-47
transformer_engine/common/fused_attn/fused_attn.cpp
transformer_engine/common/fused_attn/fused_attn.cpp
+40
-40
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
...gine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+268
-142
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
...ngine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
+24
-22
transformer_engine/common/fused_attn/fused_attn_fp8.cu
transformer_engine/common/fused_attn/fused_attn_fp8.cu
+4
-2
transformer_engine/common/fused_attn/utils.h
transformer_engine/common/fused_attn/utils.h
+3
-2
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
...mmon/hadamard_transform/hadamard_transform_cast_fusion.cu
+14
-13
transformer_engine/common/include/transformer_engine/fused_attn.h
...mer_engine/common/include/transformer_engine/fused_attn.h
+42
-37
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
...mmon/transpose/quantize_transpose_vector_blockwise_fp4.cu
+39
-37
transformer_engine/common/util/nvfp4_transpose.cuh
transformer_engine/common/util/nvfp4_transpose.cuh
+142
-148
transformer_engine/common/util/ptx.cuh
transformer_engine/common/util/ptx.cuh
+259
-51
No files found.
tests/jax/test_distributed_softmax.py
View file @
0a5016b1
...
@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
...
@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
with
mesh
,
autocast
(
mesh_resource
=
mesh_resource
):
with
mesh
,
autocast
(
mesh_resource
=
mesh_resource
):
x_
=
jax
.
device_put
(
x
,
NamedSharding
(
mesh
,
x_pspec
))
x_named_sharding
=
NamedSharding
(
mesh
,
x_pspec
)
mask_
=
jax
.
device_put
(
mask
,
NamedSharding
(
mesh
,
mask_pspec
))
mask_named_sharding
=
NamedSharding
(
mesh
,
mask_pspec
)
x_
=
jax
.
device_put
(
x
,
x_named_sharding
)
mask_
=
jax
.
device_put
(
mask
,
mask_named_sharding
)
with
warnings
.
catch_warnings
(
record
=
True
)
as
warns
:
with
warnings
.
catch_warnings
(
record
=
True
)
as
warns
:
try
:
try
:
...
@@ -116,8 +118,8 @@ class TestDistributedSoftmax:
...
@@ -116,8 +118,8 @@ class TestDistributedSoftmax:
grad_args
=
(
0
,),
grad_args
=
(
0
,),
metric_fwd_dtype
=
dtype
,
metric_fwd_dtype
=
dtype
,
metric_bwd_dtype
=
dtype
,
metric_bwd_dtype
=
dtype
,
in_shardings
=
(
x_
pspec
,
mask_pspec
),
in_shardings
=
(
x_
named_sharding
,
mask_named_sharding
),
out_shardings
=
(
None
,
(
x_
pspec
,)
),
out_shardings
=
(
None
,
x_
named_sharding
),
)
)
except
AssertionError
as
err
:
except
AssertionError
as
err
:
# Softmax should still produce the correct numerical result with
# Softmax should still produce the correct numerical result with
...
...
tests/jax/test_fused_attn.py
View file @
0a5016b1
...
@@ -378,14 +378,14 @@ class FusedAttnRunner:
...
@@ -378,14 +378,14 @@ class FusedAttnRunner:
pytest
.
skip
(
pytest
.
skip
(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
)
# TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
if
(
if
(
get_device_compute_capability
(
0
)
=
=
100
get_device_compute_capability
(
0
)
>
=
100
and
self
.
dropout_prob
==
0.1
and
self
.
dropout_prob
==
0.1
and
self
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
and
self
.
attn_bias_type
is
not
AttnBiasType
.
NO_BIAS
):
):
pytest
.
skip
(
pytest
.
skip
(
"For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
"For sm100
+
, bprop kernel support for dropout + determinism (bias) is not supported"
)
)
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
...
...
tests/jax/test_helper.py
View file @
0a5016b1
...
@@ -3,11 +3,13 @@
...
@@ -3,11 +3,13 @@
# See LICENSE for license information.
# See LICENSE for license information.
import
unittest
import
unittest
from
functools
import
partial
import
flax
import
flax
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
import
numpy
as
np
import
numpy
as
np
from
flax
import
linen
as
nn
from
utils
import
assert_allclose
from
utils
import
assert_allclose
from
transformer_engine.common.recipe
import
(
from
transformer_engine.common.recipe
import
(
...
@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import (
...
@@ -24,15 +26,51 @@ from transformer_engine.jax.quantize import (
ScalingMode
,
ScalingMode
,
update_collections
,
update_collections
,
TensorSource
,
TensorSource
,
QuantizerFactory
,
QuantizeLayout
,
)
)
from
transformer_engine.jax.quantize.helper
import
_format2dtypes
from
transformer_engine.jax.quantize.helper
import
_format2dtypes
from
transformer_engine.jax.sharding
import
MeshResource
,
global_mesh_resource
from
transformer_engine.jax.sharding
import
MeshResource
,
global_mesh_resource
from
transformer_engine.jax.flax.module
import
TransformerEngineBase
is_fp8_supported
,
reason
=
is_scaling_mode_supported
(
ScalingMode
.
DELAYED_TENSOR_SCALING
)
is_fp8_supported
,
reason
=
is_scaling_mode_supported
(
ScalingMode
.
DELAYED_TENSOR_SCALING
)
is_mxfp8_supported
,
mxfp8_reason
=
is_scaling_mode_supported
(
ScalingMode
.
MXFP8_1D_SCALING
)
is_mxfp8_supported
,
mxfp8_reason
=
is_scaling_mode_supported
(
ScalingMode
.
MXFP8_1D_SCALING
)
is_nvfp4_supported
,
nvfp4_reason
=
is_scaling_mode_supported
(
ScalingMode
.
NVFP4_1D_SCALING
)
is_nvfp4_supported
,
nvfp4_reason
=
is_scaling_mode_supported
(
ScalingMode
.
NVFP4_1D_SCALING
)
def
quantizer_check_vjp
(
outer_quantizer_set
,
assertion_func
,
x
):
"""Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
# Define a function with a custom VJP (vector-Jacobian product)
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
1
,))
def
quantizer_check
(
inner_quantizer_set
,
assertion_func
,
x
):
return
quantizer_check_fwd
(
inner_quantizer_set
,
assertion_func
,
x
)
def
quantizer_check_fwd
(
inner_quantizer_set
,
assertion_func
,
x
):
assertion_func
(
inner_quantizer_set
.
x
,
TensorSource
.
X
)
assertion_func
(
inner_quantizer_set
.
kernel
,
TensorSource
.
KERNEL
)
assertion_func
(
inner_quantizer_set
.
dgrad
,
TensorSource
.
DGRAD
)
return
x
def
quantizer_check_bwd
(
ctx
,
g
):
return
(
g
,)
quantizer_check
.
defvjp
(
quantizer_check_fwd
,
quantizer_check_bwd
)
return
quantizer_check
(
outer_quantizer_set
,
assertion_func
,
x
)
class
TestModule
(
TransformerEngineBase
):
"""A simple module to test quantizer creation and reconstruction across VJP boundaries."""
# Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
assertion_func
:
callable
@
nn
.
compact
def
__call__
(
self
,
x
):
quantizer_set
=
self
.
generate_quantizer_set
()
return
quantizer_check_vjp
(
quantizer_set
,
self
.
assertion_func
,
x
)
class
TestHelper
(
unittest
.
TestCase
):
class
TestHelper
(
unittest
.
TestCase
):
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
...
@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase):
...
@@ -89,12 +127,43 @@ class TestFP8Functions(unittest.TestCase):
for
tensor_source
in
TensorSource
:
for
tensor_source
in
TensorSource
:
target_scaling_mode
=
(
target_scaling_mode
=
(
ScalingMode
.
NVFP4_2D_SCALING
ScalingMode
.
NVFP4_2D_SCALING
if
tensor_source
==
TensorSource
.
KERNEL
if
(
not
test
.
disable_2d_quantization
)
and
tensor_source
==
TensorSource
.
KERNEL
else
ScalingMode
.
NVFP4_1D_SCALING
else
ScalingMode
.
NVFP4_1D_SCALING
)
)
self
.
assertEqual
(
self
.
assertEqual
(
get_quantize_config
().
get_scaling_mode
(
tensor_source
),
target_scaling_mode
get_quantize_config
().
get_scaling_mode
(
tensor_source
),
target_scaling_mode
)
)
self
.
assertEqual
(
get_quantize_config
().
DISABLE_STOCHASTIC_ROUNDING
,
test
.
disable_stochastic_rounding
)
self
.
assertEqual
(
get_quantize_config
().
DISABLE_RHT
,
test
.
disable_rht
)
self
.
assertEqual
(
get_quantize_config
().
DISABLE_2D_QUANTIZATION
,
test
.
disable_2d_quantization
)
def
_compare_nvfp4_scaling_quantizers
(
self
,
test
):
"""Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries."""
def
assertion_func
(
quantizer
,
tensor_source
):
if
test
.
disable_stochastic_rounding
or
tensor_source
!=
TensorSource
.
DGRAD
:
self
.
assertIsNone
(
quantizer
.
stochastic_rounding_rng_state
)
else
:
self
.
assertIsNotNone
(
quantizer
.
stochastic_rounding_rng_state
)
expected_rht
=
(
quantizer
.
scaling_mode
==
ScalingMode
.
NVFP4_1D_SCALING
and
quantizer
.
q_layout
in
{
QuantizeLayout
.
ROWWISE_COLWISE
,
QuantizeLayout
.
COLWISE
}
and
not
test
.
disable_rht
)
self
.
assertEqual
(
quantizer
.
use_rht
,
expected_rht
)
x
=
jnp
.
ones
((),
dtype
=
jnp
.
float32
)
test_module
=
TestModule
(
assertion_func
=
assertion_func
)
param_key
,
sr_key
=
jax
.
random
.
split
(
jax
.
random
.
PRNGKey
(
0
))
rngs
=
{
"params"
:
param_key
,
"sr_rng"
:
sr_key
}
variables
=
test_module
.
init
(
rngs
,
x
)
jax
.
jit
(
jax
.
value_and_grad
(
test_module
.
apply
),
static_argnums
=
(
2
,))(
variables
,
x
,
rngs
=
rngs
)
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_autocast_delayed_scaling
(
self
):
def
test_autocast_delayed_scaling
(
self
):
...
@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase):
...
@@ -171,5 +240,16 @@ class TestFP8Functions(unittest.TestCase):
with
autocast
(
enabled
=
True
,
recipe
=
bs
,
mesh_resource
=
MeshResource
()):
with
autocast
(
enabled
=
True
,
recipe
=
bs
,
mesh_resource
=
MeshResource
()):
self
.
assertTrue
(
get_quantize_config
().
is_fp8_enabled
())
self
.
assertTrue
(
get_quantize_config
().
is_fp8_enabled
())
self
.
_compare_nvfp4_scaling
(
bs
)
self
.
_compare_nvfp4_scaling
(
bs
)
self
.
_compare_nvfp4_scaling_quantizers
(
bs
)
bs
=
NVFP4BlockScaling
(
disable_stochastic_rounding
=
True
,
disable_rht
=
True
,
disable_2d_quantization
=
True
,
)
with
autocast
(
enabled
=
True
,
recipe
=
bs
,
mesh_resource
=
MeshResource
()):
self
.
assertTrue
(
get_quantize_config
().
is_fp8_enabled
())
self
.
_compare_nvfp4_scaling
(
bs
)
self
.
_compare_nvfp4_scaling_quantizers
(
bs
)
self
.
_check_default_state
()
self
.
_check_default_state
()
tests/pytorch/attention/run_attention_with_cp.py
View file @
0a5016b1
...
@@ -248,6 +248,7 @@ def run_dpa_with_cp(
...
@@ -248,6 +248,7 @@ def run_dpa_with_cp(
attn_mask_type
=
config
.
attn_mask_type
,
attn_mask_type
=
config
.
attn_mask_type
,
window_size
=
config
.
window_size
,
window_size
=
config
.
window_size
,
softmax_type
=
config
.
softmax_type
,
softmax_type
=
config
.
softmax_type
,
return_max_logit
=
config
.
return_max_logit
,
).
cuda
()
).
cuda
()
if
config
.
softmax_type
!=
"vanilla"
:
if
config
.
softmax_type
!=
"vanilla"
:
core_attn
.
softmax_offset
.
requires_grad
=
True
core_attn
.
softmax_offset
.
requires_grad
=
True
...
@@ -308,6 +309,7 @@ def run_dpa_with_cp(
...
@@ -308,6 +309,7 @@ def run_dpa_with_cp(
fp8_context
=
autocast
(
enabled
=
True
,
recipe
=
fp8_recipe
,
amax_reduction_group
=
cp_comm_group
)
fp8_context
=
autocast
(
enabled
=
True
,
recipe
=
fp8_recipe
,
amax_reduction_group
=
cp_comm_group
)
else
:
else
:
fp8_context
=
nullcontext
()
fp8_context
=
nullcontext
()
max_logit
=
None
with
fp8_context
:
with
fp8_context
:
# q, k, v, out in FP8; dout in F16
# q, k, v, out in FP8; dout in F16
out
=
core_attn
(
out
=
core_attn
(
...
@@ -322,6 +324,8 @@ def run_dpa_with_cp(
...
@@ -322,6 +324,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded
=
cu_seqlens_kv_padded
,
cu_seqlens_kv_padded
=
cu_seqlens_kv_padded
,
fp8_output
=
fp8_mha
,
fp8_output
=
fp8_mha
,
)
)
if
config
.
return_max_logit
:
out
,
max_logit
=
out
if
fp8_bwd
and
fp8_mha
:
if
fp8_bwd
and
fp8_mha
:
dout_fp8
=
dout_quantizer
(
dout
)
dout_fp8
=
dout_quantizer
(
dout
)
out
.
backward
(
dout_fp8
)
out
.
backward
(
dout_fp8
)
...
@@ -400,6 +404,7 @@ def run_dpa_with_cp(
...
@@ -400,6 +404,7 @@ def run_dpa_with_cp(
fp8_context
=
nullcontext
()
fp8_context
=
nullcontext
()
# run attention
# run attention
max_logit_
=
None
with
fp8_context
:
with
fp8_context
:
# q, k, v, out in FP8; dout in F16
# q, k, v, out in FP8; dout in F16
out_
=
core_attn
(
out_
=
core_attn
(
...
@@ -414,6 +419,8 @@ def run_dpa_with_cp(
...
@@ -414,6 +419,8 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded
=
cu_seqlens_kv_padded
,
cu_seqlens_kv_padded
=
cu_seqlens_kv_padded
,
fp8_output
=
fp8_mha
,
fp8_output
=
fp8_mha
,
)
)
if
config
.
return_max_logit
:
out_
,
max_logit_
=
out_
if
fp8_bwd
and
fp8_mha
:
if
fp8_bwd
and
fp8_mha
:
dout_fp8_
=
dout_quantizer
(
dout_
)
dout_fp8_
=
dout_quantizer
(
dout_
)
out_
.
backward
(
dout_fp8_
)
out_
.
backward
(
dout_fp8_
)
...
@@ -495,15 +502,15 @@ def run_dpa_with_cp(
...
@@ -495,15 +502,15 @@ def run_dpa_with_cp(
)
)
atol
,
rtol
,
rmse_tol
=
get_tols
(
config
,
dtype
)
atol
,
rtol
,
rmse_tol
=
get_tols
(
config
,
dtype
)
tensors_cp
=
[
out_
,
dq_
,
dk_
,
dv_
,
d_softmax_offset_
]
tensors_cp
=
[
out_
,
dq_
,
dk_
,
dv_
,
d_softmax_offset_
,
max_logit_
]
tensors_no_cp
=
[
out
,
dq
,
dk
,
dv
,
d_softmax_offset
]
tensors_no_cp
=
[
out
,
dq
,
dk
,
dv
,
d_softmax_offset
,
max_logit
]
names
=
[
"out"
,
"dq"
,
"dk"
,
"dv"
,
"d_softmax_offset"
]
names
=
[
"out"
,
"dq"
,
"dk"
,
"dv"
,
"d_softmax_offset"
,
"max_logit"
]
names_cp
=
[
x
+
"_cp"
for
x
in
names
]
names_cp
=
[
x
+
"_cp"
for
x
in
names
]
names_no_cp
=
[
x
+
"_no_cp"
for
x
in
names
]
names_no_cp
=
[
x
+
"_no_cp"
for
x
in
names
]
is_fp8
=
dtype
==
"fp8"
is_fp8
=
dtype
==
"fp8"
for
i
,
t
in
enumerate
(
tensors_no_cp
):
for
i
,
t
in
enumerate
(
tensors_no_cp
):
if
t
is
not
None
:
if
t
is
not
None
:
if
"softmax_offset"
not
in
names
[
i
]:
if
"softmax_offset"
not
in
names
[
i
]
and
"max_logit"
not
in
names
[
i
]:
if
qkv_format
==
"bshd"
:
if
qkv_format
==
"bshd"
:
compare_and_assert
(
compare_and_assert
(
t
[:,
0
],
t
[:,
0
],
...
...
tests/pytorch/attention/test_attention.py
View file @
0a5016b1
...
@@ -60,8 +60,16 @@ from utils import (
...
@@ -60,8 +60,16 @@ from utils import (
get_available_attention_backends
,
get_available_attention_backends
,
)
)
# Check if hardware supports FP8
# Check if hardware supports FP8
attention.
fp8_available
,
reason_for_no_fp8
=
is_fp8_available
(
return_reason
=
True
)
fp8_available
,
reason_for_no_fp8
=
is_fp8_available
(
return_reason
=
True
)
fp8_attn_available
,
reason_for_no_fp8_attn
=
fp8_available
,
reason_for_no_fp8
device_compute_capability
=
get_device_compute_capability
()
if
fp8_available
and
(
device_compute_capability
<
(
9
,
0
)
or
device_compute_capability
>=
(
12
,
0
)):
fp8_attn_available
=
False
reason_for_no_fp8_attn
=
(
"FP8 attention is not supported for compute capability ="
f
" sm
{
device_compute_capability
[
0
]
*
10
+
device_compute_capability
[
1
]
}
"
)
# Reset RNG seed and states
# Reset RNG seed and states
seed
=
1234
seed
=
1234
...
@@ -130,6 +138,11 @@ def test_dot_product_attention(
...
@@ -130,6 +138,11 @@ def test_dot_product_attention(
if
config
.
window_size
==
(
-
1
,
-
1
)
and
swa
:
if
config
.
window_size
==
(
-
1
,
-
1
)
and
swa
:
config
.
window_size
=
[
2
,
2
]
config
.
window_size
=
[
2
,
2
]
config
.
window_size
=
check_set_window_size
(
config
.
attn_mask_type
,
config
.
window_size
)
config
.
window_size
=
check_set_window_size
(
config
.
attn_mask_type
,
config
.
window_size
)
qkv_format
=
qkv_layout
.
replace
(
"3"
,
""
).
replace
(
"2"
,
""
).
split
(
"_"
)[
0
]
if
qkv_format
==
"thd"
and
"padding"
not
in
config
.
attn_mask_type
:
config
.
attn_mask_type
=
(
"padding_"
+
config
.
attn_mask_type
if
config
.
attn_mask_type
!=
"no_mask"
else
"padding"
)
# Get backends
# Get backends
is_training
=
True
is_training
=
True
...
@@ -171,7 +184,7 @@ def test_dot_product_attention(
...
@@ -171,7 +184,7 @@ def test_dot_product_attention(
# UnfusedDotProductAttention backend
# UnfusedDotProductAttention backend
if
unfused_attn_supported
:
if
unfused_attn_supported
:
unfused_attn_fwd
,
unfused_attn_bwd
=
_run_dot_product_attention
(
unfused_attn_fwd
,
unfused_max_logit
,
unfused_attn_bwd
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"UnfusedDotProductAttention"
,
"UnfusedDotProductAttention"
,
...
@@ -185,7 +198,7 @@ def test_dot_product_attention(
...
@@ -185,7 +198,7 @@ def test_dot_product_attention(
# FusedAttention backend
# FusedAttention backend
if
fused_attn_supported
:
if
fused_attn_supported
:
if
len
(
fused_attn_backends
)
==
1
:
if
len
(
fused_attn_backends
)
==
1
:
fused_attn_fwd
,
fused_attn_bwd
=
_run_dot_product_attention
(
fused_attn_fwd
,
fused_max_logit
,
fused_attn_bwd
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"FusedAttention"
,
"FusedAttention"
,
...
@@ -197,7 +210,7 @@ def test_dot_product_attention(
...
@@ -197,7 +210,7 @@ def test_dot_product_attention(
)
)
if
len
(
fused_attn_backends
)
==
2
:
if
len
(
fused_attn_backends
)
==
2
:
os
.
environ
[
"NVTE_FUSED_ATTN_BACKEND"
]
=
"0"
os
.
environ
[
"NVTE_FUSED_ATTN_BACKEND"
]
=
"0"
fused_attn_fwd
,
fused_attn_bwd
=
_run_dot_product_attention
(
fused_attn_fwd
,
_
,
fused_attn_bwd
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"FusedAttention"
,
"FusedAttention"
,
...
@@ -208,7 +221,7 @@ def test_dot_product_attention(
...
@@ -208,7 +221,7 @@ def test_dot_product_attention(
is_training
,
is_training
,
)
)
os
.
environ
[
"NVTE_FUSED_ATTN_BACKEND"
]
=
"1"
os
.
environ
[
"NVTE_FUSED_ATTN_BACKEND"
]
=
"1"
fused_attn_fwd_1
,
fused_attn_bwd_1
=
_run_dot_product_attention
(
fused_attn_fwd_1
,
_
,
fused_attn_bwd_1
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"FusedAttention"
,
"FusedAttention"
,
...
@@ -221,7 +234,7 @@ def test_dot_product_attention(
...
@@ -221,7 +234,7 @@ def test_dot_product_attention(
# FlashAttention backend
# FlashAttention backend
if
flash_attn_supported
:
if
flash_attn_supported
:
flash_attn_fwd
,
flash_attn_bwd
=
_run_dot_product_attention
(
flash_attn_fwd
,
_
,
flash_attn_bwd
=
_run_dot_product_attention
(
dtype
,
dtype
,
config
,
config
,
"FlashAttention"
,
"FlashAttention"
,
...
@@ -242,6 +255,8 @@ def test_dot_product_attention(
...
@@ -242,6 +255,8 @@ def test_dot_product_attention(
if
unfused_attn_supported
and
fused_attn_supported
:
if
unfused_attn_supported
and
fused_attn_supported
:
logging
.
info
(
"[test_dot_product_attention]: unfused attn vs fused attn"
)
logging
.
info
(
"[test_dot_product_attention]: unfused attn vs fused attn"
)
torch
.
testing
.
assert_close
(
fused_attn_fwd
,
unfused_attn_fwd
,
**
tols
)
torch
.
testing
.
assert_close
(
fused_attn_fwd
,
unfused_attn_fwd
,
**
tols
)
if
config
.
return_max_logit
:
torch
.
testing
.
assert_close
(
fused_max_logit
,
unfused_max_logit
,
**
tols
)
for
i
,
_
in
enumerate
(
unfused_attn_bwd
):
for
i
,
_
in
enumerate
(
unfused_attn_bwd
):
torch
.
testing
.
assert_close
(
fused_attn_bwd
[
i
],
unfused_attn_bwd
[
i
],
**
tols
)
torch
.
testing
.
assert_close
(
fused_attn_bwd
[
i
],
unfused_attn_bwd
[
i
],
**
tols
)
if
fused_attn_supported
and
flash_attn_supported
:
if
fused_attn_supported
and
flash_attn_supported
:
...
@@ -265,6 +280,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
...
@@ -265,6 +280,33 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
True
,
True
,
None
,
False
,
False
)
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
True
,
True
,
None
,
False
,
False
)
model_configs_max_logit
=
{
# test: ModelConfig(b, sq, hq, dqk)
"max_logit_1"
:
ModelConfig
(
1
,
2048
,
24
,
128
,
max_seqlen_kv
=
4096
),
"max_logit_2"
:
ModelConfig
(
2
,
2048
,
24
,
128
,
attn_mask_type
=
"causal"
),
"max_logit_3"
:
ModelConfig
(
2
,
1
,
16
,
128
,
max_seqlen_kv
=
2048
,
attn_mask_type
=
"padding_causal"
),
"max_logit_4"
:
ModelConfig
(
8
,
128
,
16
,
192
,
max_seqlen_kv
=
2048
,
attn_bias_type
=
"post_scale_bias"
),
"max_logit_5"
:
ModelConfig
(
8
,
128
,
16
,
512
,
max_seqlen_kv
=
2048
,
attn_mask_type
=
"causal"
,
window_size
=
(
20
,
0
)
),
"max_logit_6"
:
ModelConfig
(
8
,
1
,
16
,
1024
,
max_seqlen_kv
=
2048
),
}
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
8
,
9
,
1
),
reason
=
"cuDNN 8.9.1+ is required."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types
)
@
pytest
.
mark
.
parametrize
(
"model_configs"
,
[
model_configs_max_logit
])
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_max_logit
.
keys
())
@
pytest
.
mark
.
parametrize
(
"qkv_layout"
,
[
"sbhd_sbhd_sbhd"
,
"thd_thd_thd"
])
def
test_dpa_max_logit
(
dtype
,
model_configs
,
model
,
qkv_layout
):
"""Test DotProductAttention module with checkpointing"""
config
=
model_configs
[
model
]
config
.
return_max_logit
=
True
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
False
,
True
,
qkv_layout
,
False
,
False
)
model_configs_softmax
=
{
model_configs_softmax
=
{
# test: ModelConfig(b, sq, hq, dqk)
# test: ModelConfig(b, sq, hq, dqk)
"softmax_1_0"
:
ModelConfig
(
2
,
2048
,
64
,
64
,
num_gqa_groups
=
8
),
"softmax_1_0"
:
ModelConfig
(
2
,
2048
,
64
,
64
,
num_gqa_groups
=
8
),
...
@@ -962,6 +1004,8 @@ def _run_dot_product_attention(
...
@@ -962,6 +1004,8 @@ def _run_dot_product_attention(
layout
=
layout
.
replace
(
"d"
,
"dqk"
)
layout
=
layout
.
replace
(
"d"
,
"dqk"
)
tensor_shape
=
[
dim_to_num
[
j
]
for
j
in
layout
.
split
(
"_"
)]
tensor_shape
=
[
dim_to_num
[
j
]
for
j
in
layout
.
split
(
"_"
)]
tensor
=
0.1
*
torch
.
randn
(
tensor_shape
,
dtype
=
dtype
,
device
=
"cuda"
)
tensor
=
0.1
*
torch
.
randn
(
tensor_shape
,
dtype
=
dtype
,
device
=
"cuda"
)
# tensor: with padding tokens
# tensor_orig: without padding tokens
tensor_orig
=
tensor
tensor_orig
=
tensor
if
qkv_format
==
"thd"
and
pad_between_seqs
:
if
qkv_format
==
"thd"
and
pad_between_seqs
:
tensor_orig
=
torch
.
Tensor
([]).
to
(
device
=
"cuda"
,
dtype
=
dtype
)
tensor_orig
=
torch
.
Tensor
([]).
to
(
device
=
"cuda"
,
dtype
=
dtype
)
...
@@ -1071,6 +1115,7 @@ def _run_dot_product_attention(
...
@@ -1071,6 +1115,7 @@ def _run_dot_product_attention(
layer_number
=
1
,
layer_number
=
1
,
attention_type
=
config
.
attn_type
,
attention_type
=
config
.
attn_type
,
softmax_type
=
config
.
softmax_type
,
softmax_type
=
config
.
softmax_type
,
return_max_logit
=
config
.
return_max_logit
,
).
to
(
dtype
=
dtype
,
device
=
"cuda"
)
).
to
(
dtype
=
dtype
,
device
=
"cuda"
)
if
not
is_training
:
if
not
is_training
:
block
=
block
.
eval
()
block
=
block
.
eval
()
...
@@ -1108,16 +1153,21 @@ def _run_dot_product_attention(
...
@@ -1108,16 +1153,21 @@ def _run_dot_product_attention(
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
fast_zero_fill
=
True
,
fast_zero_fill
=
True
,
)
)
max_logit
=
None
if
config
.
return_max_logit
:
out
,
max_logit
=
out
if
is_training
:
if
is_training
:
out
.
backward
(
d_out
)
out
.
backward
(
d_out
)
d_softmax_offset
=
None
d_softmax_offset
=
None
if
is_training
and
config
.
softmax_type
!=
"vanilla"
:
if
is_training
and
config
.
softmax_type
!=
"vanilla"
:
d_softmax_offset
=
block
.
softmax_offset
.
grad
d_softmax_offset
=
block
.
softmax_offset
.
grad
if
backend
in
[
"FlashAttention"
,
"UnfusedDotProductAttention"
]:
if
backend
in
[
"FlashAttention"
,
"UnfusedDotProductAttention"
]:
if
is_training
:
if
is_training
:
return
out
,
(
q
.
grad
,
k
.
grad
,
v
.
grad
,
d_softmax_offset
)
return
out
,
max_logit
,
(
q
.
grad
,
k
.
grad
,
v
.
grad
,
d_softmax_offset
)
else
:
else
:
return
out
,
(
None
,
None
,
None
,
d_softmax_offset
)
return
out
,
max_logit
,
(
None
,
None
,
None
,
d_softmax_offset
)
if
backend
==
"FusedAttention"
:
if
backend
==
"FusedAttention"
:
if
qkv_format
==
"thd"
and
pad_between_seqs
:
if
qkv_format
==
"thd"
and
pad_between_seqs
:
out_orig
=
torch
.
Tensor
([]).
to
(
device
=
"cuda"
,
dtype
=
dtype
)
out_orig
=
torch
.
Tensor
([]).
to
(
device
=
"cuda"
,
dtype
=
dtype
)
...
@@ -1146,14 +1196,18 @@ def _run_dot_product_attention(
...
@@ -1146,14 +1196,18 @@ def _run_dot_product_attention(
[
v_grad_orig
,
v
.
grad
[
valid_range_kv
[
0
]
:
valid_range_kv
[
1
]]],
dim
=
0
[
v_grad_orig
,
v
.
grad
[
valid_range_kv
[
0
]
:
valid_range_kv
[
1
]]],
dim
=
0
)
)
if
is_training
:
if
is_training
:
return
out_orig
,
(
q_grad_orig
,
k_grad_orig
,
v_grad_orig
,
d_softmax_offset
)
return
(
out_orig
,
max_logit
,
(
q_grad_orig
,
k_grad_orig
,
v_grad_orig
,
d_softmax_offset
),
)
else
:
else
:
return
out_orig
,
(
None
,
None
,
None
,
d_softmax_offset
)
return
out_orig
,
max_logit
,
(
None
,
None
,
None
,
d_softmax_offset
)
else
:
else
:
if
is_training
:
if
is_training
:
return
out
,
(
q
.
grad
,
k
.
grad
,
v
.
grad
,
d_softmax_offset
)
return
out
,
max_logit
,
(
q
.
grad
,
k
.
grad
,
v
.
grad
,
d_softmax_offset
)
else
:
else
:
return
out
,
(
None
,
None
,
None
,
d_softmax_offset
)
return
out
,
max_logit
,
(
None
,
None
,
None
,
d_softmax_offset
)
model_configs_te_layer
=
{
model_configs_te_layer
=
{
...
@@ -1527,8 +1581,7 @@ model_configs_fp8_extra_state = {
...
@@ -1527,8 +1581,7 @@ model_configs_fp8_extra_state = {
}
}
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
@
pytest
.
mark
.
skipif
(
not
fp8_attn_available
,
reason
=
reason_for_no_fp8_attn
)
@
pytest
.
mark
.
skipif
(
get_device_compute_capability
()
<
(
9
,
0
),
reason
=
"FP8 tests require Hopper."
)
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
3
,
0
),
reason
=
"cuDNN 9.3.0+ is required."
)
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
3
,
0
),
reason
=
"cuDNN 9.3.0+ is required."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"large"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"large"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
...
@@ -1690,8 +1743,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
...
@@ -1690,8 +1743,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
2
,
1
),
reason
=
"cuDNN 9.2.1+ is required."
)
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
2
,
1
),
reason
=
"cuDNN 9.2.1+ is required."
)
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
@
pytest
.
mark
.
skipif
(
not
fp8_attn_available
,
reason
=
reason_for_no_fp8_attn
)
@
pytest
.
mark
.
skipif
(
get_device_compute_capability
()
<
(
9
,
0
),
reason
=
"FP8 tests require Hopper+."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_fp8_vs_f16
.
keys
())
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_fp8_vs_f16
.
keys
())
@
pytest
.
mark
.
parametrize
(
"qkv_format"
,
qkv_format_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"qkv_format"
,
qkv_format_fp8_vs_f16
)
...
@@ -1927,8 +1979,7 @@ def _run_mha_fp8_vs_f16(
...
@@ -1927,8 +1979,7 @@ def _run_mha_fp8_vs_f16(
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
2
,
1
),
reason
=
"cuDNN 9.2.1+ is required."
)
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
9
,
2
,
1
),
reason
=
"cuDNN 9.2.1+ is required."
)
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
@
pytest
.
mark
.
skipif
(
not
fp8_attn_available
,
reason
=
reason_for_no_fp8_attn
)
@
pytest
.
mark
.
skipif
(
get_device_compute_capability
()
<
(
9
,
0
),
reason
=
"FP8 tests require Hopper+."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_fp8_vs_f16
.
keys
())
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_fp8_vs_f16
.
keys
())
@
pytest
.
mark
.
parametrize
(
"qkv_layout"
,
qkv_layout_fp8_vs_f16
)
@
pytest
.
mark
.
parametrize
(
"qkv_layout"
,
qkv_layout_fp8_vs_f16
)
...
@@ -2256,8 +2307,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
...
@@ -2256,8 +2307,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
),
),
reason
=
f
"""cuDNN
{
"8.9.3"
if
cudnn_frontend_version
==
0
else
"9.2.1"
}
+ is required."""
,
reason
=
f
"""cuDNN
{
"8.9.3"
if
cudnn_frontend_version
==
0
else
"9.2.1"
}
+ is required."""
,
)
)
@
pytest
.
mark
.
skipif
(
not
fp8_available
,
reason
=
reason_for_no_fp8
)
@
pytest
.
mark
.
skipif
(
not
fp8_attn_available
,
reason
=
reason_for_no_fp8_attn
)
@
pytest
.
mark
.
skipif
(
get_device_compute_capability
()
<
(
9
,
0
),
reason
=
"FP8 tests require Hopper+."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_fp8
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models_v1
if
cudnn_frontend_version
==
1
else
models_v0
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models_v1
if
cudnn_frontend_version
==
1
else
models_v0
)
def
test_custom_mha_fp8_vs_f16
(
dtype
,
model
):
def
test_custom_mha_fp8_vs_f16
(
dtype
,
model
):
...
...
tests/pytorch/attention/test_attention_with_cp.py
View file @
0a5016b1
...
@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
...
@@ -138,8 +138,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
model_configs_fused_attn
=
{
model_configs_fused_attn
=
{
# test: ModelConfig(b, sq, hq, dqk)
# test: ModelConfig(b, sq, hq, dqk)
"cp_1_0"
:
ModelConfig
(
2
,
4096
,
12
,
128
,
attn_mask_type
=
"causal"
),
# MHA
"cp_1_0"
:
ModelConfig
(
2
,
4096
,
12
,
128
,
attn_mask_type
=
"causal"
,
return_max_logit
=
True
),
# MHA
"cp_1_1"
:
ModelConfig
(
2
,
4096
,
12
,
128
),
# MHA
"cp_1_1"
:
ModelConfig
(
2
,
4096
,
12
,
128
,
return_max_logit
=
True
),
# MHA
"cp_1_2"
:
ModelConfig
(
"cp_1_2"
:
ModelConfig
(
2
,
4096
,
12
,
128
,
attn_mask_type
=
"causal"
,
attn_bias_type
=
"post_scale_bias"
2
,
4096
,
12
,
128
,
attn_mask_type
=
"causal"
,
attn_bias_type
=
"post_scale_bias"
),
# MHA
),
# MHA
...
@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
...
@@ -184,7 +184,7 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats
=
[
"bshd"
,
"sbhd"
,
"thd"
]
qkv_formats
=
[
"bshd"
,
"sbhd"
,
"thd"
]
cp_comm_types
=
[
"p2p"
,
"all_gather"
,
"a2a"
,
"a2a+p2p"
]
cp_comm_types
=
[
"p2p"
,
"all_gather"
,
"a2a"
,
"a2a+p2p"
]
if
test_essential
:
if
test_essential
:
configs
=
[
"cp_1_0"
,
"cp_2_0"
,
"cp_2_2"
,
"cp_3_2"
,
"cp_4_2"
]
configs
=
[
"cp_1_0"
,
"cp_1_1"
,
"cp_2_0"
,
"cp_2_2"
,
"cp_3_2"
,
"cp_4_2"
]
model_configs_fused_attn
=
{
k
:
model_configs_fused_attn
[
k
]
for
k
in
configs
}
model_configs_fused_attn
=
{
k
:
model_configs_fused_attn
[
k
]
for
k
in
configs
}
dtypes
=
[
"bf16"
,
"fp8"
]
dtypes
=
[
"bf16"
,
"fp8"
]
qkv_formats
=
[
"sbhd"
,
"thd"
]
qkv_formats
=
[
"sbhd"
,
"thd"
]
...
...
tests/pytorch/test_numerics.py
View file @
0a5016b1
...
@@ -45,11 +45,10 @@ from transformer_engine.pytorch import (
...
@@ -45,11 +45,10 @@ from transformer_engine.pytorch import (
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
import
torch_version
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch
import
checkpoint
as
te_checkpoint
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.cpp_extensions
import
general_gemm
,
general_grouped_gemm
from
transformer_engine.pytorch.cpp_extensions.fused_attn
import
FusedAttnBackend
from
transformer_engine.pytorch.module.base
import
get_multi_stream_cublas_workspace
,
get_workspace
from
transformer_engine.pytorch.module.base
import
get_multi_stream_cublas_workspace
,
get_workspace
from
transformer_engine.common
import
recipe
from
transformer_engine.common
import
recipe
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
utils
import
ModelConfig
,
reset_rng_states
,
get_available_attention_backends
from
utils
import
ModelConfig
,
reset_rng_states
# Only run FP8 tests on supported devices.
# Only run FP8 tests on supported devices.
...
@@ -135,23 +134,6 @@ if torch.cuda.get_device_capability() == (9, 0):
...
@@ -135,23 +134,6 @@ if torch.cuda.get_device_capability() == (9, 0):
use_cutlass_grouped_gemm
.
append
(
True
)
use_cutlass_grouped_gemm
.
append
(
True
)
def
is_fused_attn_available
(
config
:
ModelConfig
,
dtype
:
torch
.
dtype
,
qkv_layout
=
"bshd_bshd_bshd"
,
is_training
=
True
,
deterministic
=
False
,
):
_
,
_
,
fused_attn_backends
=
get_available_attention_backends
(
config
,
qkv_dtype
=
dtype
,
qkv_layout
=
qkv_layout
,
is_training
=
is_training
,
deterministic
=
deterministic
,
)
return
FusedAttnBackend
[
"F16_arbitrary_seqlen"
]
in
fused_attn_backends
def
get_causal_attn_mask
(
sq
:
int
)
->
torch
.
Tensor
:
def
get_causal_attn_mask
(
sq
:
int
)
->
torch
.
Tensor
:
return
torch
.
triu
(
torch
.
ones
(
sq
,
sq
,
device
=
"cuda"
),
diagonal
=
1
).
bool
()
return
torch
.
triu
(
torch
.
ones
(
sq
,
sq
,
device
=
"cuda"
),
diagonal
=
1
).
bool
()
...
@@ -872,8 +854,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
...
@@ -872,8 +854,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"126m"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"126m"
])
def
test_gpt_checkpointing
(
dtype
,
bs
,
model
):
def
test_gpt_checkpointing
(
dtype
,
bs
,
model
):
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
not
is_fused_attn_available
(
config
,
dtype
,
deterministic
=
True
):
pytest
.
skip
(
"No attention backend available."
)
outputs
=
_test_e2e_checkpointing
(
bs
,
dtype
,
config
,
checkpoint
=
False
)
outputs
=
_test_e2e_checkpointing
(
bs
,
dtype
,
config
,
checkpoint
=
False
)
outputs_checkpoint
=
_test_e2e_checkpointing
(
bs
,
dtype
,
config
,
checkpoint
=
True
)
outputs_checkpoint
=
_test_e2e_checkpointing
(
bs
,
dtype
,
config
,
checkpoint
=
True
)
...
@@ -920,10 +900,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
...
@@ -920,10 +900,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@
pytest
.
mark
.
parametrize
(
"parallel_attention_mlp"
,
all_boolean
)
@
pytest
.
mark
.
parametrize
(
"parallel_attention_mlp"
,
all_boolean
)
def
test_gpt_accuracy
(
dtype
,
bs
,
model
,
parallel_attention_mlp
):
def
test_gpt_accuracy
(
dtype
,
bs
,
model
,
parallel_attention_mlp
):
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
not
is_fused_attn_available
(
config
,
dtype
,
qkv_layout
=
"sb3hd"
,
is_training
=
True
,
deterministic
=
True
):
pytest
.
skip
(
"No attention backend available."
)
te_gpt
=
TransformerLayer
(
te_gpt
=
TransformerLayer
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
...
@@ -1035,10 +1011,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
...
@@ -1035,10 +1011,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@
pytest
.
mark
.
parametrize
(
"mask_type"
,
mask_types
)
@
pytest
.
mark
.
parametrize
(
"mask_type"
,
mask_types
)
def
test_mha_accuracy
(
dtype
,
bs
,
model
,
mask_type
):
def
test_mha_accuracy
(
dtype
,
bs
,
model
,
mask_type
):
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
if
not
is_fused_attn_available
(
config
,
dtype
,
qkv_layout
=
"sb3hd"
,
is_training
=
True
,
deterministic
=
True
):
pytest
.
skip
(
"No attention backend available."
)
te_mha
=
MultiheadAttention
(
te_mha
=
MultiheadAttention
(
config
.
hidden_size
,
config
.
hidden_size
,
...
...
tests/pytorch/utils.py
View file @
0a5016b1
...
@@ -205,6 +205,7 @@ class ModelConfig:
...
@@ -205,6 +205,7 @@ class ModelConfig:
window_size
:
Tuple
[
int
,
int
]
=
(
-
1
,
-
1
),
window_size
:
Tuple
[
int
,
int
]
=
(
-
1
,
-
1
),
context_parallel
:
bool
=
False
,
context_parallel
:
bool
=
False
,
cp_comm_type
:
str
=
"p2p"
,
cp_comm_type
:
str
=
"p2p"
,
return_max_logit
=
False
,
total_requests
:
int
=
None
,
total_requests
:
int
=
None
,
max_ctx_len
:
int
=
None
,
max_ctx_len
:
int
=
None
,
num_layers
:
int
=
1
,
num_layers
:
int
=
1
,
...
@@ -233,6 +234,7 @@ class ModelConfig:
...
@@ -233,6 +234,7 @@ class ModelConfig:
self
.
window_size
=
check_set_window_size
(
self
.
attn_mask_type
,
window_size
)
self
.
window_size
=
check_set_window_size
(
self
.
attn_mask_type
,
window_size
)
self
.
context_parallel
=
context_parallel
self
.
context_parallel
=
context_parallel
self
.
cp_comm_type
=
cp_comm_type
self
.
cp_comm_type
=
cp_comm_type
self
.
return_max_logit
=
return_max_logit
self
.
total_requests
=
total_requests
self
.
total_requests
=
total_requests
self
.
max_ctx_len
=
max_ctx_len
self
.
max_ctx_len
=
max_ctx_len
self
.
num_layers
=
num_layers
self
.
num_layers
=
num_layers
...
@@ -318,6 +320,7 @@ def get_available_attention_backends(
...
@@ -318,6 +320,7 @@ def get_available_attention_backends(
is_training
=
is_training
,
is_training
=
is_training
,
inference_params
=
inference_params
,
inference_params
=
inference_params
,
softmax_type
=
config
.
softmax_type
,
softmax_type
=
config
.
softmax_type
,
return_max_logit
=
config
.
return_max_logit
,
)
)
(
(
use_flash_attention
,
use_flash_attention
,
...
...
transformer_engine/common/CMakeLists.txt
View file @
0a5016b1
...
@@ -29,15 +29,6 @@ endif()
...
@@ -29,15 +29,6 @@ endif()
# Language options
# Language options
if
(
USE_CUDA
)
if
(
USE_CUDA
)
if
(
NOT DEFINED CMAKE_CUDA_ARCHITECTURES
)
if
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0
)
set
(
CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120
)
elseif
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8
)
set
(
CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120
)
else
()
set
(
CMAKE_CUDA_ARCHITECTURES 70 80 89 90
)
endif
()
endif
()
set
(
CMAKE_CXX_STANDARD 17
)
set
(
CMAKE_CXX_STANDARD 17
)
set
(
CMAKE_CUDA_STANDARD 17
)
set
(
CMAKE_CUDA_STANDARD 17
)
set
(
CMAKE_CUDA_STANDARD_REQUIRED ON
)
set
(
CMAKE_CUDA_STANDARD_REQUIRED ON
)
...
@@ -54,8 +45,62 @@ if(USE_CUDA)
...
@@ -54,8 +45,62 @@ if(USE_CUDA)
# CUDA Toolkit
# CUDA Toolkit
find_package
(
CUDAToolkit REQUIRED
)
find_package
(
CUDAToolkit REQUIRED
)
if
(
CUDAToolkit_VERSION VERSION_LESS 12.0
)
if
(
CUDAToolkit_VERSION VERSION_LESS 12.1
)
message
(
FATAL_ERROR
"CUDA 12.0+ is required, but found CUDA
${
CUDAToolkit_VERSION
}
"
)
message
(
FATAL_ERROR
"CUDA 12.1+ is required, but found CUDA
${
CUDAToolkit_VERSION
}
"
)
endif
()
# Process GPU architectures
if
(
NOT DEFINED CMAKE_CUDA_ARCHITECTURES
)
if
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0
)
set
(
CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120
)
elseif
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8
)
set
(
CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120
)
else
()
set
(
CMAKE_CUDA_ARCHITECTURES 70 80 89 90
)
endif
()
endif
()
# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures
set
(
NVTE_GENERIC_ARCHS
)
set
(
NVTE_SPECIFIC_ARCHS
)
# Check for architecture 100
list
(
FIND CMAKE_CUDA_ARCHITECTURES
"100"
arch_100_index
)
if
(
NOT arch_100_index EQUAL -1
)
list
(
REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES
"100"
)
list
(
APPEND NVTE_GENERIC_ARCHS
"100"
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"100a"
)
if
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"103a"
)
endif
()
endif
()
# Check for architecture 101 (if we see this we are in toolkit <= 12.9)
list
(
FIND CMAKE_CUDA_ARCHITECTURES
"101"
arch_101_index
)
if
(
NOT arch_101_index EQUAL -1
)
list
(
REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES
"101"
)
list
(
APPEND NVTE_GENERIC_ARCHS
"101"
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"101a"
)
endif
()
# Check for architecture 110 (if we see this we are in toolkit >= 13.0)
list
(
FIND CMAKE_CUDA_ARCHITECTURES
"110"
arch_110_index
)
if
(
NOT arch_110_index EQUAL -1
)
list
(
REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES
"110"
)
list
(
APPEND NVTE_GENERIC_ARCHS
"110"
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"110f"
)
endif
()
# Check for architecture 120
list
(
FIND CMAKE_CUDA_ARCHITECTURES
"120"
arch_120_index
)
if
(
NOT arch_120_index EQUAL -1
)
list
(
REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES
"120"
)
list
(
APPEND NVTE_GENERIC_ARCHS
"120"
)
if
(
CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9
)
list
(
APPEND NVTE_SPECIFIC_ARCHS
"120f"
)
else
()
list
(
APPEND NVTE_SPECIFIC_ARCHS
"120a"
)
endif
()
endif
()
endif
()
# cuDNN frontend API
# cuDNN frontend API
...
@@ -135,11 +180,29 @@ endif()
...
@@ -135,11 +180,29 @@ endif()
# Configure Transformer Engine library
# Configure Transformer Engine library
include_directories
(
${
PROJECT_SOURCE_DIR
}
/..
)
include_directories
(
${
PROJECT_SOURCE_DIR
}
/..
)
set
(
transformer_engine_SOURCES
)
set
(
transformer_engine_SOURCES
)
set
(
transformer_engine_cpp_sources
)
set
(
transformer_engine_cuda_sources
)
set
(
transformer_engine_cuda_arch_specific_sources
)
if
(
USE_CUDA
)
if
(
USE_CUDA
)
list
(
APPEND transformer_engine_
SOURCES
list
(
APPEND transformer_engine_
cpp_sources
cudnn_utils.cpp
cudnn_utils.cpp
transformer_engine.cpp
transformer_engine.cpp
fused_attn/fused_attn.cpp
gemm/config.cpp
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/rmsnorm/rmsnorm_api.cpp
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/comm_gemm_overlap.cpp
)
list
(
APPEND transformer_engine_cuda_sources
common.cu
common.cu
multi_tensor/adam.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/compute_scale.cu
...
@@ -151,40 +214,23 @@ if(USE_CUDA)
...
@@ -151,40 +214,23 @@ if(USE_CUDA)
transpose/cast_transpose_fusion.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
transpose/swap_first_dims.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
activation/gelu.cu
dropout/dropout.cu
dropout/dropout.cu
fused_attn/flash_attn.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
activation/swiglu.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
fused_attn/utils.cu
gemm/config.cpp
gemm/cublaslt_gemm.cu
gemm/cublaslt_gemm.cu
gemm/cutlass_grouped_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_masked_softmax.cu
...
@@ -198,25 +244,84 @@ if(USE_CUDA)
...
@@ -198,25 +244,84 @@ if(USE_CUDA)
recipe/delayed_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/fp8_block_scaling.cu
recipe/nvfp4.cu
recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers.cu
)
list
(
APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
)
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
# Compiling the files with the worst compilation time first to hopefully overlap
comm_gemm_overlap/userbuffers/userbuffers.cu
# better with the faster-compiling cpp files
comm_gemm_overlap/comm_gemm_overlap.cpp
)
list
(
APPEND transformer_engine_SOURCES
${
transformer_engine_cuda_arch_specific_sources
}
${
transformer_engine_cuda_sources
}
${
transformer_engine_cpp_sources
}
)
# Set compile options for CUDA sources with generic architectures
foreach
(
cuda_source IN LISTS transformer_engine_cuda_sources
)
set
(
arch_compile_options
)
foreach
(
arch IN LISTS NVTE_GENERIC_ARCHS
)
list
(
APPEND arch_compile_options
"--generate-code=arch=compute_
${
arch
}
,code=sm_
${
arch
}
"
)
endforeach
()
if
(
arch_compile_options
)
set_property
(
SOURCE
${
cuda_source
}
APPEND
PROPERTY
COMPILE_OPTIONS
${
arch_compile_options
}
)
endif
()
endforeach
()
# Set compile options for CUDA sources with specific architectures
foreach
(
cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources
)
set
(
arch_compile_options
)
foreach
(
arch IN LISTS NVTE_SPECIFIC_ARCHS
)
list
(
APPEND arch_compile_options
"--generate-code=arch=compute_
${
arch
}
,code=sm_
${
arch
}
"
)
endforeach
()
if
(
arch_compile_options
)
set_property
(
SOURCE
${
cuda_source
}
APPEND
PROPERTY
COMPILE_OPTIONS
${
arch_compile_options
}
)
endif
()
endforeach
()
if
(
NVTE_WITH_CUBLASMP
)
if
(
NVTE_WITH_CUBLASMP
)
list
(
APPEND transformer_engine_SOURCES
list
(
APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp
)
comm_gemm/comm_gemm.cpp
)
endif
()
endif
()
add_library
(
transformer_engine SHARED
${
transformer_engine_SOURCES
}
)
add_library
(
transformer_engine SHARED
${
transformer_engine_SOURCES
}
)
else
()
else
()
list
(
APPEND transformer_engine_
SOURCES
list
(
APPEND transformer_engine_
cpp_sources
cudnn_utils.cpp
cudnn_utils.cpp
transformer_engine.cpp
transformer_engine.cpp
gemm/config.cpp
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/rmsnorm/rmsnorm_api.cpp
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/comm_gemm_overlap.cpp
)
list
(
APPEND transformer_engine_cuda_sources
common.cu
common.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
multi_tensor/adam.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/l2norm.cu
...
@@ -227,31 +332,23 @@ else()
...
@@ -227,31 +332,23 @@ else()
transpose/cast_transpose_fusion.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
transpose/multi_cast_transpose.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
transpose/swap_first_dims.cu
transpose/swap_first_dims.cu
activation/gelu.cu
dropout/dropout.cu
dropout/dropout.cu
activation/relu.cu
fused_attn/flash_attn.cu
activation/swiglu.cu
fused_attn/context_parallel.cu
gemm/config.cpp
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
fused_attn/fused_attn_fp8.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/cublaslt_gemm.cu
gemm/hipblas_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/layernorm/ln_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_api.cpp
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
permutation/permutation.cu
permutation/permutation.cu
util/cast.cu
util/padding.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_masked_softmax.cu
...
@@ -264,10 +361,25 @@ else()
...
@@ -264,10 +361,25 @@ else()
recipe/current_scaling.cu
recipe/current_scaling.cu
recipe/delayed_scaling.cu
recipe/delayed_scaling.cu
recipe/fp8_block_scaling.cu
recipe/fp8_block_scaling.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
recipe/nvfp4.cu
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
)
comm_gemm_overlap/userbuffers/userbuffers.cu
comm_gemm_overlap/comm_gemm_overlap.cpp
)
list
(
APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
util/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
list
(
APPEND transformer_engine_SOURCES
${
transformer_engine_cuda_arch_specific_sources
}
${
transformer_engine_cuda_sources
}
${
transformer_engine_cpp_sources
}
)
if
(
NVTE_WITH_CUBLASMP
)
if
(
NVTE_WITH_CUBLASMP
)
list
(
APPEND transformer_engine_SOURCES
list
(
APPEND transformer_engine_SOURCES
comm_gemm/comm_gemm.cpp
)
comm_gemm/comm_gemm.cpp
)
...
@@ -316,8 +428,10 @@ if (USE_CUDA)
...
@@ -316,8 +428,10 @@ if (USE_CUDA)
CUDA::cublas
CUDA::cublas
CUDA::cudart
CUDA::cudart
CUDNN::cudnn_all
)
CUDNN::cudnn_all
)
target_include_directories
(
transformer_engine PRIVATE
target_include_directories
(
transformer_engine PRIVATE
${
CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES
}
)
${
CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES
}
)
target_include_directories
(
transformer_engine PRIVATE
${
MATHDX_INCLUDE_DIR
}
)
target_include_directories
(
transformer_engine SYSTEM PRIVATE
target_include_directories
(
transformer_engine SYSTEM PRIVATE
${
CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES
}
/cccl
)
${
CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES
}
/cccl
)
target_include_directories
(
transformer_engine PRIVATE
"
${
CUDNN_FRONTEND_INCLUDE_DIR
}
"
)
target_include_directories
(
transformer_engine PRIVATE
"
${
CUDNN_FRONTEND_INCLUDE_DIR
}
"
)
...
@@ -436,7 +550,8 @@ target_include_directories(transformer_engine PRIVATE
...
@@ -436,7 +550,8 @@ target_include_directories(transformer_engine PRIVATE
"
${
CMAKE_CURRENT_BINARY_DIR
}
/string_headers"
)
"
${
CMAKE_CURRENT_BINARY_DIR
}
/string_headers"
)
# Compiler options
# Compiler options
set_source_files_properties
(
fused_softmax/scaled_masked_softmax.cu
set
(
nvte_sources_with_fast_math
)
list
(
APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
multi_tensor/adam.cu
multi_tensor/adam.cu
...
@@ -446,20 +561,25 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
...
@@ -446,20 +561,25 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
multi_tensor/sgd.cu
multi_tensor/sgd.cu
fused_attn/flash_attn.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/kv_cache.cu
)
PROPERTIES
COMPILE_OPTIONS
"--use_fast_math"
)
option
(
NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
"Compile activation kernels with --use_fast_math option"
OFF
)
option
(
NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
"Compile activation kernels with --use_fast_math option"
OFF
)
if
(
NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
)
if
(
NVTE_BUILD_ACTIVATION_WITH_FAST_MATH
)
set_source_files_properties
(
activation/gelu.cu
list
(
APPEND nvte_sources_with_fast_math
activation/gelu.cu
activation/relu.cu
activation/relu.cu
activation/swiglu.cu
activation/swiglu.cu
util/cast.cu
util/cast.cu
)
PROPERTIES
COMPILE_OPTIONS
"--use_fast_math"
)
endif
()
endif
()
if
(
USE_CUDA
)
if
(
USE_CUDA
)
foreach
(
cuda_source IN LISTS nvte_sources_with_fast_math
)
set_property
(
SOURCE
${
cuda_source
}
APPEND
PROPERTY
COMPILE_OPTIONS
"--use_fast_math"
)
endforeach
()
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
--expt-relaxed-constexpr"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
--expt-relaxed-constexpr"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-O3"
)
set
(
CMAKE_CUDA_FLAGS
"
${
CMAKE_CUDA_FLAGS
}
-O3"
)
else
()
else
()
...
...
transformer_engine/common/__init__.py
View file @
0a5016b1
...
@@ -8,22 +8,18 @@ import ctypes
...
@@ -8,22 +8,18 @@ import ctypes
import
functools
import
functools
import
glob
import
glob
import
importlib
import
importlib
from
importlib.metadata
import
version
,
metadata
,
PackageNotFoundError
from
importlib.metadata
import
version
,
distribution
,
PackageNotFoundError
import
logging
import
os
import
os
from
pathlib
import
Path
from
pathlib
import
Path
import
platform
import
platform
import
subprocess
import
subprocess
import
sys
import
sys
import
sysconfig
import
sysconfig
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
_logger
=
logging
.
getLogger
(
__name__
)
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_is_
pip_
package_installed
(
package
)
->
bool
:
def
_is_package_installed
(
package
)
->
bool
:
"""Check if the given package is installed via pip."""
"""Check if the given package is installed via pip."""
# This is needed because we only want to return true
# This is needed because we only want to return true
...
@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool:
...
@@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool:
# if it's importable in the current directory due to
# if it's importable in the current directory due to
# the presence of the shared library module.
# the presence of the shared library module.
try
:
try
:
metadata
(
package
)
distribution
(
package
)
except
PackageNotFoundError
:
except
PackageNotFoundError
:
return
False
return
False
return
True
return
True
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_is_package_installed_from_wheel
(
package
)
->
bool
:
"""Check if the given package is installed via PyPI."""
if
not
_is_package_installed
(
package
):
return
False
te_dist
=
distribution
(
package
)
te_wheel_file
=
""
for
file_path
in
te_dist
.
files
:
if
file_path
.
name
==
"WHEEL"
:
te_wheel_file
=
te_dist
.
locate_file
(
""
)
/
file_path
if
not
te_wheel_file
:
return
False
with
te_wheel_file
.
open
(
"r"
)
as
f
:
for
line
in
f
:
if
line
.
startswith
(
"Root-Is-Purelib:"
):
return
line
.
strip
().
split
(
":"
)[
1
].
strip
().
lower
()
==
"true"
return
False
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_find_shared_object_in_te_dir
(
te_path
:
Path
,
prefix
:
str
)
->
Optional
[
Path
]:
def
_find_shared_object_in_te_dir
(
te_path
:
Path
,
prefix
:
str
)
->
Optional
[
Path
]:
"""
"""
...
@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path:
...
@@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path:
)
)
def
get_te_core_package_info
()
->
Tuple
[
bool
,
str
,
str
]:
"""
Check if Tranformer Engine core package is installed.
Returns the module name and version if found.
"""
te_core_packages
=
(
"transformer-engine-cu12"
,
"transformer-engine-cu13"
)
for
package
in
te_core_packages
:
if
_is_package_installed
(
package
):
return
True
,
package
,
version
(
package
)
return
False
,
""
,
""
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
load_framework_extension
(
framework
:
str
)
->
None
:
def
load_framework_extension
(
framework
:
str
)
->
None
:
"""
"""
...
@@ -130,37 +161,28 @@ def load_framework_extension(framework: str) -> None:
...
@@ -130,37 +161,28 @@ def load_framework_extension(framework: str) -> None:
if
framework
==
"torch"
:
if
framework
==
"torch"
:
extra_dep_name
=
"pytorch"
extra_dep_name
=
"pytorch"
# Find the TE packages. The core and framework packages can only be installed via PyPI.
# For the `transformer-engine` package, we need to check explicity.
te_core_installed
,
te_core_package_name
,
te_core_version
=
get_te_core_package_info
()
te_framework_installed
=
_is_package_installed
(
module_name
)
te_installed
=
_is_package_installed
(
"transformer_engine"
)
te_installed_via_pypi
=
_is_package_installed_from_wheel
(
"transformer_engine"
)
assert
te_installed
,
"Could not find `transformer_engine`."
# If the framework extension pip package is installed, it means that TE is installed via
# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version.
# extension are all installed via PyPI and have matching versions.
if
_is_pip_package_installed
(
module_name
):
if
te_framework_installed
:
assert
_is_pip_package_installed
(
assert
te_installed_via_pypi
,
"Could not find `transformer-engine` PyPI package."
"transformer_engine"
assert
te_core_installed
,
"Could not find TE core package `transformer-engine-cu*`."
),
"Could not find `transformer-engine`."
assert
_is_pip_package_installed
(
"transformer_engine_cu12"
),
"Could not find `transformer-engine-cu12`."
assert
(
version
(
module_name
)
==
version
(
"transformer-engine"
)
==
version
(
"transformer-engine-cu12"
)
),
(
"TransformerEngine package version mismatch. Found"
f
"
{
module_name
}
v
{
version
(
module_name
)
}
, transformer-engine"
f
" v
{
version
(
'transformer-engine'
)
}
, and transformer-engine-cu12"
f
" v
{
version
(
'transformer-engine-cu12'
)
}
. Install transformer-engine using "
f
"'pip3 install transformer-engine[
{
extra_dep_name
}
]==VERSION'"
)
# If the core package is installed via PyPI, log if
assert
version
(
module_name
)
==
version
(
"transformer-engine"
)
==
te_core_version
,
(
# the framework extension is not found from PyPI.
"Transformer Engine package version mismatch. Found"
# Note: Should we error? This is a rare use case.
f
"
{
module_name
}
v
{
version
(
module_name
)
}
, transformer-engine"
if
_is_pip_package_installed
(
"transformer-engine-cu12"
):
f
" v
{
version
(
'transformer-engine'
)
}
, and
{
te_core_package_name
}
"
if
not
_is_pip_package_installed
(
module_name
):
f
" v
{
te_core_version
}
. Install transformer-engine using "
_logger
.
info
(
f
"'pip3 install --no-build-isolation transformer-engine[
{
extra_dep_name
}
]==VERSION'"
"Could not find package %s. Install transformer-engine using "
f
"'pip3 install transformer-engine[
{
extra_dep_name
}
]==VERSION'"
,
module_name
,
)
)
# After all checks are completed, load the shared object file.
# After all checks are completed, load the shared object file.
...
@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None:
...
@@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None:
spec
.
loader
.
exec_module
(
solib
)
spec
.
loader
.
exec_module
(
solib
)
def
sanity_checks_for_pypi_installation
()
->
None
:
"""Ensure that package is installed correctly if using PyPI."""
te_core_installed
,
te_core_package_name
,
te_core_version
=
get_te_core_package_info
()
te_installed
=
_is_package_installed
(
"transformer_engine"
)
te_installed_via_pypi
=
_is_package_installed_from_wheel
(
"transformer_engine"
)
assert
te_installed
,
"Could not find `transformer-engine`."
# If the core package is installed via PyPI.
if
te_core_installed
:
assert
te_installed_via_pypi
,
"Could not find `transformer-engine` PyPI package."
assert
version
(
"transformer-engine"
)
==
te_core_version
,
(
"Transformer Engine package version mismatch. Found "
f
"transformer-engine v
{
version
(
'transformer-engine'
)
}
"
f
"and
{
te_core_package_name
}
v
{
te_core_version
}
."
)
# Only the metapackage is found, invalid usecase.
elif
te_installed_via_pypi
:
raise
RuntimeError
(
"Found empty `transformer-engine` meta package installed. "
"Install `transformer-engine` with framework extensions via"
"'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'"
" or 'pip3 install transformer-engine[core]` for the TE core lib only. The `core_cu12`"
" or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib."
)
@
functools
.
lru_cache
(
maxsize
=
None
)
@
functools
.
lru_cache
(
maxsize
=
None
)
def
_get_sys_extension
()
->
str
:
def
_get_sys_extension
()
->
str
:
"""File extension for shared objects."""
"""File extension for shared objects."""
...
@@ -339,16 +390,14 @@ def _load_core_library():
...
@@ -339,16 +390,14 @@ def _load_core_library():
if
"NVTE_PROJECT_BUILDING"
not
in
os
.
environ
or
bool
(
int
(
os
.
getenv
(
"NVTE_RELEASE_BUILD"
,
"0"
))):
if
"NVTE_PROJECT_BUILDING"
not
in
os
.
environ
or
bool
(
int
(
os
.
getenv
(
"NVTE_RELEASE_BUILD"
,
"0"
))):
try
:
sanity_checks_for_pypi_installation
()
_CUDNN_LIB_CTYPES
=
_load_cudnn
()
_CUDNN_LIB_CTYPES
=
_load_cudnn
()
_NVRTC_LIB_CTYPES
=
_load_nvrtc
()
_NVRTC_LIB_CTYPES
=
_load_nvrtc
()
_CURAND_LIB_CTYPES
=
_load_curand
()
_CURAND_LIB_CTYPES
=
_load_curand
()
_CUBLAS_LIB_CTYPES
=
_load_nvidia_cuda_library
(
"cublas"
)
_CUBLAS_LIB_CTYPES
=
_load_nvidia_cuda_library
(
"cublas"
)
_CUDART_LIB_CTYPES
=
_load_nvidia_cuda_library
(
"cuda_runtime"
)
_CUDART_LIB_CTYPES
=
_load_nvidia_cuda_library
(
"cuda_runtime"
)
_TE_LIB_CTYPES
=
_load_core_library
()
# Needed to find the correct headers for NVRTC kernels.
# Needed to find the correct headers for NVRTC kernels.
if
not
os
.
getenv
(
"NVTE_CUDA_INCLUDE_DIR"
)
and
_nvidia_cudart_include_dir
():
if
not
os
.
getenv
(
"NVTE_CUDA_INCLUDE_DIR"
)
and
_nvidia_cudart_include_dir
():
os
.
environ
[
"NVTE_CUDA_INCLUDE_DIR"
]
=
_nvidia_cudart_include_dir
()
os
.
environ
[
"NVTE_CUDA_INCLUDE_DIR"
]
=
_nvidia_cudart_include_dir
()
except
OSError
:
pass
_TE_LIB_CTYPES
=
_load_core_library
()
transformer_engine/common/fused_attn/fused_attn.cpp
View file @
0a5016b1
...
@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
int64_t
window_size_right
,
bool
return_max_logit
)
{
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
NVTE_Fused_Attn_Backend
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_No_Backend
;
NVTE_Fused_Attn_Backend
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_No_Backend
;
const
int
device_id
=
cuda
::
current_device
();
const
int
device_id
=
cuda
::
current_device
();
...
@@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_BSHD
||
qkv_format
==
NVTE_QKV_Format
::
NVTE_SBHD
)
&&
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_BSHD
||
qkv_format
==
NVTE_QKV_Format
::
NVTE_SBHD
)
&&
!
requires_64bit_ragged_offset
&&
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
)
&&
!
requires_64bit_ragged_offset
&&
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
)
&&
// 9.10.0: known bugs with SDPA FP8
// 9.10.0: known bugs with SDPA FP8
(
cudnn_runtime_version
!=
91000
))
{
(
cudnn_runtime_version
!=
91000
)
&&
!
return_max_logit
)
{
if
(
cudnn_runtime_version
>=
8900
)
{
if
(
cudnn_runtime_version
>=
8900
)
{
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_FP8
;
backend
=
NVTE_Fused_Attn_Backend
::
NVTE_FP8
;
}
else
{
}
else
{
...
@@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(
qkv_layout
==
NVTE_QKV_Layout
::
NVTE_BSHD_BSHD_BSHD
))
&&
(
qkv_layout
==
NVTE_QKV_Layout
::
NVTE_BSHD_BSHD_BSHD
))
&&
((
window_size_left
==
-
1
)
&&
(
window_size_right
==
-
1
||
window_size_right
==
0
))
&&
((
window_size_left
==
-
1
)
&&
(
window_size_right
==
-
1
||
window_size_right
==
0
))
&&
!
requires_64bit_ragged_offset
&&
!
requires_64bit_ragged_offset
&&
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
))
{
(
softmax_type
==
NVTE_Softmax_Type
::
NVTE_VANILLA_SOFTMAX
)
&&
!
return_max_logit
)
{
flag_m512
=
true
;
flag_m512
=
true
;
}
}
if
(
if
(
...
@@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
...
@@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
size_t
max_seqlen
,
bool
is_training
,
float
attn_scale
,
size_t
max_seqlen
,
bool
is_training
,
bool
return_max_logit
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
int64_t
window_size_right
,
NVTETensor
workspace
,
...
@@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
...
@@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
is_training
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
is_training
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h
,
h
,
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
);
h
,
h
,
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
,
return_max_logit
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
...
@@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
}
else
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
)
{
}
else
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
)
{
#if (CUDNN_VERSION >= 8900)
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
b
,
h
,
max_seqlen
,
d
,
t
,
is_training
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
b
,
h
,
max_seqlen
,
d
,
t
,
is_training
,
return_max_logit
,
attn_scale
,
dropout
,
qkv_layout
,
attn_mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
input_QKV
,
input_Bias
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
input_QKV
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens
,
input_cu_seqlens_padded
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens
,
input_rng_state
,
wkspace
,
stream
,
handle
);
input_cu_seqlens_padded
,
input_rng_state
,
wkspace
,
stream
,
handle
);
#else
#else
NVTE_ERROR
(
NVTE_ERROR
(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
...
@@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
...
@@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
true
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h
,
h
,
true
,
QKV_type
,
QKV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h
,
h
,
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
);
max_seqlen
,
max_seqlen
,
d
,
d
,
window_size_left
,
window_size_right
,
false
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked(
...
@@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked(
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
NVTETensor
workspace
,
cudaStream_t
stream
)
{
...
@@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked(
...
@@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
);
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
,
return_max_logit
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked(
...
@@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked(
#if (CUDNN_VERSION >= 8903)
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked
(
fused_attn_arbitrary_seqlen_fwd_kvpacked
(
b
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
t_q
,
t_kv
,
num_pages_k
,
num_pages_v
,
b
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
t_q
,
t_kv
,
num_pages_k
,
num_pages_v
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
attn_scale
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
return_max_logit
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_right
,
input_Q
,
input_KV
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
window_size_left
,
window_size_right
,
input_Q
,
input_KV
,
input_Bias
,
input_SoftmaxOffset
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_q_padded
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_
kv
_padded
,
input_
page_table_k
,
input_page_table_
v
,
input_rng_state
,
input_cu_seqlens_
q
_padded
,
input_
cu_seqlens_kv_padded
,
input_page_table_
k
,
wkspace
,
stream
,
handle
);
input_page_table_v
,
input_rng_state
,
wkspace
,
stream
,
handle
);
#else
#else
NVTE_ERROR
(
NVTE_ERROR
(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
...
@@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked(
...
@@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
true
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
true
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
);
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d
,
d
,
window_size_left
,
window_size_right
,
false
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -832,17 +833,15 @@ void nvte_fused_attn_bwd_kvpacked(
...
@@ -832,17 +833,15 @@ void nvte_fused_attn_bwd_kvpacked(
}
}
}
}
// NVTE fused attention FWD with separate Q, K and V
// NVTE fused attention FWD with separate Q, K and V
void
nvte_fused_attn_fwd
(
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
void
nvte_fused_attn_fwd
(
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
const
NVTETensor
Bias
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_flash_attn_fwd
);
NVTE_API_CALL
(
nvte_flash_attn_fwd
);
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
...
@@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
...
@@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
is_training
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
);
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
,
return_max_logit
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
@@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
...
@@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#if (CUDNN_VERSION >= 8900)
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd
(
fused_attn_arbitrary_seqlen_fwd
(
b
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
t_q
,
t_kv
,
num_pages_k
,
num_pages_v
,
b
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
t_q
,
t_kv
,
num_pages_k
,
num_pages_v
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
attn_scale
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
is_training
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_left
,
return_max_logit
,
attn_scale
,
dropout
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
window_size_right
,
input_Q
,
input_K
,
input_V
,
input_Bias
,
input_SoftmaxOffset
,
output_O
,
window_size_left
,
window_size_right
,
input_Q
,
input_K
,
input_V
,
input_Bias
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_q_padded
,
input_SoftmaxOffset
,
output_O
,
Aux_CTX_Tensors
,
input_cu_seqlens_q
,
input_cu_seqlens_kv
,
input_cu_seqlens_
kv
_padded
,
input_
page_table_k
,
input_page_table_
v
,
input_rng_state
,
input_cu_seqlens_
q
_padded
,
input_
cu_seqlens_kv_padded
,
input_page_table_
k
,
wkspace
,
stream
,
handle
);
input_page_table_v
,
input_rng_state
,
wkspace
,
stream
,
handle
);
#else
#else
NVTE_ERROR
(
NVTE_ERROR
(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.
\n
"
);
...
@@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
...
@@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
fused_attention_backend
=
nvte_get_fused_attn_backend
(
true
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
true
,
Q_type
,
KV_type
,
qkv_layout
,
bias_type
,
attn_mask_type
,
softmax_type
,
dropout
,
h_q
,
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
);
h_kv
,
max_seqlen_q
,
max_seqlen_kv
,
d_qk
,
d_v
,
window_size_left
,
window_size_right
,
false
);
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
if
(
fused_attention_backend
==
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
{
#if (CUDNN_VERSION >= 8901)
#if (CUDNN_VERSION >= 8901)
...
...
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
View file @
0a5016b1
...
@@ -53,10 +53,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -53,10 +53,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t
max_b
,
int64_t
max_t_q
,
int64_t
max_t_kv
,
int64_t
num_pages_k
,
int64_t
num_pages_v
,
int64_t
max_b
,
int64_t
max_t_q
,
int64_t
max_t_kv
,
int64_t
num_pages_k
,
int64_t
num_pages_v
,
int64_t
page_size_k
,
int64_t
page_size_v
,
int64_t
max_pages_per_seq_k
,
int64_t
page_size_k
,
int64_t
page_size_v
,
int64_t
max_pages_per_seq_k
,
int64_t
max_pages_per_seq_v
,
int64_t
bias_b
,
int64_t
bias_h
,
bool
is_training
,
int64_t
max_pages_per_seq_v
,
int64_t
bias_b
,
int64_t
bias_h
,
bool
is_training
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_QKV_Layout
layout
,
bool
return_max_logit
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_QKV_Layout
layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
void
*
devPtrQ
,
void
*
devPtrK
,
int64_t
window_size_left
,
int64_t
window_size_right
,
void
*
devPtrQ
,
void
*
devPtrK
,
void
*
devPtrV
,
void
*
devPtrBias
,
void
*
devPtrSoftmaxOffset
,
void
*
devPtrS
oftmaxStats
,
void
*
devPtrV
,
void
*
devPtrBias
,
void
*
devPtrSoftmaxOffset
,
void
*
devPtrS
1
,
void
*
devPtrS2
,
void
*
devPtrO
,
void
*
devPtrDropoutSeed
,
void
*
devPtrDropoutOffset
,
void
*
devPtrCuSeqlensQ
,
void
*
devPtrO
,
void
*
devPtrDropoutSeed
,
void
*
devPtrDropoutOffset
,
void
*
devPtrCuSeqlensQ
,
void
*
devPtrCuSeqlensKV
,
void
*
devPtrPageTableK
,
void
*
devPtrPageTableV
,
void
*
devPtrCuSeqlensKV
,
void
*
devPtrPageTableK
,
void
*
devPtrPageTableV
,
void
*
devPtrSeqOffsetsQ
,
void
*
devPtrSeqOffsetsKV
,
cudnn_frontend
::
DataType_t
tensorType
,
void
*
devPtrSeqOffsetsQ
,
void
*
devPtrSeqOffsetsKV
,
cudnn_frontend
::
DataType_t
tensorType
,
...
@@ -102,8 +102,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -102,8 +102,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
}
}
const
DType
ragged_offset_type
=
cudnn_runtime_version
>=
90500
?
DType
::
kInt64
:
DType
::
kInt32
;
const
DType
ragged_offset_type
=
cudnn_runtime_version
>=
90500
?
DType
::
kInt64
:
DType
::
kInt32
;
bool
generate_stats
=
!
return_max_logit
;
try
{
try
{
FADescriptor_v1
descriptor
{
b
,
FADescriptor_v1
descriptor
{
b
,
h
,
h
,
hg
,
hg
,
s_q
,
s_q
,
...
@@ -131,7 +133,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -131,7 +133,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
tensorType
,
tensorType
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
};
cudnn_frontend
::
DataType_t
::
NOT_SET
,
return_max_logit
,
};
namespace
fe
=
cudnn_frontend
;
namespace
fe
=
cudnn_frontend
;
using
graph_and_tensors
=
using
graph_and_tensors
=
...
@@ -141,7 +145,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -141,7 +145,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// V
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// V
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// attn_scale
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// attn_scale
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// O
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// O
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// Stats
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// S1
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// S2
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// bias
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// bias
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// softmax_offset
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// softmax_offset
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// seq_q
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// seq_q
...
@@ -244,6 +249,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -244,6 +249,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options
=
fe
::
graph
::
SDPA_attributes
()
sdpa_options
=
fe
::
graph
::
SDPA_attributes
()
.
set_name
(
"flash_attention"
)
.
set_name
(
"flash_attention"
)
.
set_is_inference
(
false
)
.
set_is_inference
(
false
)
.
set_generate_stats
(
generate_stats
)
.
set_causal_mask
(
is_causal
)
.
set_causal_mask
(
is_causal
)
.
set_causal_mask_bottom_right
(
is_bottom_right
)
.
set_causal_mask_bottom_right
(
is_bottom_right
)
.
set_attn_scale
(
attn_scale
);
.
set_attn_scale
(
attn_scale
);
...
@@ -317,7 +323,36 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -317,7 +323,36 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options
.
set_sink_token
(
softmax_offset
);
sdpa_options
.
set_sink_token
(
softmax_offset
);
}
}
auto
[
O
,
Stats
]
=
mha_graph
->
sdpa
(
Q
,
K
,
V
,
sdpa_options
);
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
Max
,
Sum_Exp
;
if
(
is_ragged_q
&&
cudnn_runtime_version
>=
90600
)
{
offset_stats
=
mha_graph
->
tensor
(
fe
::
graph
::
Tensor_attributes
()
.
set_name
(
"offset_stats"
)
.
set_dim
({
b
+
1
,
1
,
1
,
1
})
.
set_stride
({
1
,
1
,
1
,
1
})
.
set_data_type
(
get_cudnn_fe_dtype
(
ragged_offset_type
)));
}
if
(
return_max_logit
)
{
Max
=
mha_graph
->
tensor
(
fe
::
graph
::
Tensor_attributes
()
.
set_name
(
"Max"
)
.
set_dim
({
b
,
h
,
s_q
,
1
})
.
set_data_type
(
fe
::
DataType_t
::
FLOAT
));
Sum_Exp
=
mha_graph
->
tensor
(
fe
::
graph
::
Tensor_attributes
()
.
set_name
(
"Sum_Exp"
)
.
set_dim
({
b
,
h
,
s_q
,
1
})
.
set_data_type
(
fe
::
DataType_t
::
FLOAT
));
if
(
is_ragged_q
&&
cudnn_runtime_version
>=
90600
)
{
Max
->
set_stride
({
h
*
s_q
,
1
,
h
,
1
}).
set_ragged_offset
(
offset_stats
);
Sum_Exp
->
set_stride
({
h
*
s_q
,
1
,
h
,
1
}).
set_ragged_offset
(
offset_stats
);
}
else
{
Max
->
set_stride
({
h
*
s_q
,
s_q
,
1
,
1
});
Sum_Exp
->
set_stride
({
h
*
s_q
,
s_q
,
1
,
1
});
}
sdpa_options
.
set_logit_max
(
Max
);
sdpa_options
.
set_score_sum_exp
(
Sum_Exp
);
}
auto
[
O
,
Stats
]
=
mha_graph
->
sdpa
(
Q
,
K
,
V
,
std
::
move
(
sdpa_options
));
std
::
vector
<
int64_t
>
o_stride
(
4
);
std
::
vector
<
int64_t
>
o_stride
(
4
);
generateMatrixStrides
(
b
,
h
,
s_q
,
s_kv
,
d_v
,
o_stride
.
data
(),
layout
,
generateMatrixStrides
(
b
,
h
,
s_q
,
s_kv
,
d_v
,
o_stride
.
data
(),
layout
,
...
@@ -332,18 +367,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -332,18 +367,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
O
->
set_ragged_offset
(
offset_o
);
O
->
set_ragged_offset
(
offset_o
);
}
}
if
(
!
return_max_logit
)
{
Stats
->
set_output
(
true
).
set_data_type
(
fe
::
DataType_t
::
FLOAT
).
set_dim
({
b
,
h
,
s_q
,
1
});
Stats
->
set_output
(
true
).
set_data_type
(
fe
::
DataType_t
::
FLOAT
).
set_dim
({
b
,
h
,
s_q
,
1
});
if
(
is_ragged_q
&&
cudnn_runtime_version
>=
90600
)
{
if
(
is_ragged_q
&&
cudnn_runtime_version
>=
90600
)
{
offset_stats
=
mha_graph
->
tensor
(
fe
::
graph
::
Tensor_attributes
()
.
set_name
(
"offset_stats"
)
.
set_dim
({
b
+
1
,
1
,
1
,
1
})
.
set_stride
({
1
,
1
,
1
,
1
})
.
set_data_type
(
get_cudnn_fe_dtype
(
ragged_offset_type
)));
Stats
->
set_stride
({
h
*
s_q
,
1
,
h
,
1
}).
set_ragged_offset
(
offset_stats
);
Stats
->
set_stride
({
h
*
s_q
,
1
,
h
,
1
}).
set_ragged_offset
(
offset_stats
);
}
else
{
}
else
{
Stats
->
set_stride
({
h
*
s_q
,
s_q
,
1
,
1
});
Stats
->
set_stride
({
h
*
s_q
,
s_q
,
1
,
1
});
}
}
}
std
::
tuple
<
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// Q
std
::
tuple
<
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// Q
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// K
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// K
...
@@ -351,7 +382,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -351,7 +382,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// attn_scale
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
// attn_scale
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>>
// O
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>>
// O
key_tensors_tuple
=
std
::
make_tuple
(
Q
,
K
,
V
,
attn_scale
,
O
);
key_tensors_tuple
=
std
::
make_tuple
(
Q
,
K
,
V
,
attn_scale
,
O
);
auto
Stats_tuple
=
std
::
make_tuple
(
Stats
);
auto
Stats_tuple
=
generate_stats
?
std
::
make_tuple
(
Stats
,
nullptr
)
:
std
::
make_tuple
(
Max
,
Sum_Exp
);
auto
bias_tuple
=
is_bias
?
std
::
make_tuple
(
bias
)
:
std
::
make_tuple
(
nullptr
);
auto
bias_tuple
=
is_bias
?
std
::
make_tuple
(
bias
)
:
std
::
make_tuple
(
nullptr
);
auto
softmax_offset_tuple
=
auto
softmax_offset_tuple
=
is_softmax_offset
?
std
::
make_tuple
(
softmax_offset
)
:
std
::
make_tuple
(
nullptr
);
is_softmax_offset
?
std
::
make_tuple
(
softmax_offset
)
:
std
::
make_tuple
(
nullptr
);
...
@@ -384,7 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -384,7 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
return
return_tuple
;
return
return_tuple
;
};
};
auto
[
mha_graph
,
Q
,
K
,
V
,
attn_scale
,
O
,
S
tats
,
bias
,
softmax_offset
,
seq_q
,
seq_kv
,
auto
[
mha_graph
,
Q
,
K
,
V
,
attn_scale
,
O
,
S
1
,
S2
,
bias
,
softmax_offset
,
seq_q
,
seq_kv
,
page_table_k
,
page_table_v
,
offset_q
,
offset_o
,
offset_k
,
offset_v
,
offset_stats
,
page_table_k
,
page_table_v
,
offset_q
,
offset_o
,
offset_k
,
offset_v
,
offset_stats
,
dropout_seed
,
dropout_offset
]
=
get_graph
(
sdpa_f16_fprop_cache
,
descriptor
);
dropout_seed
,
dropout_offset
]
=
get_graph
(
sdpa_f16_fprop_cache
,
descriptor
);
...
@@ -417,9 +449,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
...
@@ -417,9 +449,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
// Build variant pack
// Build variant pack
std
::
unordered_map
<
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
void
*>
variant_pack
=
{
std
::
unordered_map
<
std
::
shared_ptr
<
fe
::
graph
::
Tensor_attributes
>
,
void
*>
variant_pack
=
{
{
Q
,
devPtrQ
},
{
K
,
devPtrK
},
{
Q
,
devPtrQ
},
{
K
,
devPtrK
},
{
V
,
devPtrV
},
{
attn_scale
,
&
scaling_factor
},
{
V
,
devPtrV
},
{
attn_scale
,
&
scaling_factor
},
{
O
,
devPtrO
},
{
S1
,
devPtrS1
}};
{
O
,
devPtrO
},
{
Stats
,
devPtrSoftmaxStats
}};
if
(
return_max_logit
)
{
variant_pack
[
S2
]
=
devPtrS2
;
}
if
(
is_bias
)
{
if
(
is_bias
)
{
variant_pack
[
bias
]
=
devPtrBias
;
variant_pack
[
bias
]
=
devPtrBias
;
...
@@ -561,7 +596,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
...
@@ -561,7 +596,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const
DType
ragged_offset_type
=
cudnn_runtime_version
>=
90500
?
DType
::
kInt64
:
DType
::
kInt32
;
const
DType
ragged_offset_type
=
cudnn_runtime_version
>=
90500
?
DType
::
kInt64
:
DType
::
kInt32
;
try
{
try
{
FADescriptor_v1
descriptor
{
b
,
FADescriptor_v1
descriptor
{
b
,
h
,
h
,
hg
,
hg
,
s_q
,
s_q
,
...
@@ -589,7 +625,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
...
@@ -589,7 +625,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
tensorType
,
tensorType
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
};
cudnn_frontend
::
DataType_t
::
NOT_SET
,
false
,
};
namespace
fe
=
cudnn_frontend
;
namespace
fe
=
cudnn_frontend
;
using
graph_and_tensors
=
using
graph_and_tensors
=
...
@@ -1001,12 +1039,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
...
@@ -1001,12 +1039,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
using
namespace
transformer_engine
::
fused_attn
;
using
namespace
transformer_engine
::
fused_attn
;
void
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
void
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
bool
is_training
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
p_dropout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
const
Tensor
*
input_QKV
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
const
Tensor
*
input_Bias
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
const
Tensor
*
input_QKV
,
const
Tensor
*
input_Bias
,
const
Tensor
*
input_SoftmaxOffset
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens
,
const
Tensor
*
cu_seqlens_padded
,
Tensor
*
output_O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
)
{
const
Tensor
*
cu_seqlens_padded
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
)
{
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
const
auto
QKV_type
=
input_QKV
->
data
.
dtype
;
const
auto
QKV_type
=
input_QKV
->
data
.
dtype
;
...
@@ -1037,7 +1076,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
...
@@ -1037,7 +1076,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
}
}
void
*
devPtrO
=
output_O
->
data
.
dptr
;
void
*
devPtrO
=
output_O
->
data
.
dptr
;
void
*
devPtrS
=
nullptr
;
void
*
devPtrS1
=
nullptr
;
void
*
devPtrS2
=
nullptr
;
void
*
devPtrCuSeqlens
=
cu_seqlens
->
data
.
dptr
;
void
*
devPtrCuSeqlens
=
cu_seqlens
->
data
.
dptr
;
void
*
devPtrSeqOffsets
=
cu_seqlens_padded
->
data
.
dptr
;
void
*
devPtrSeqOffsets
=
cu_seqlens_padded
->
data
.
dptr
;
...
@@ -1051,6 +1091,24 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
...
@@ -1051,6 +1091,24 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t
i
=
0
;
size_t
i
=
0
;
if
(
Aux_CTX_Tensors
->
size
==
0
)
{
if
(
Aux_CTX_Tensors
->
size
==
0
)
{
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
if
(
return_max_logit
)
{
Tensor
*
output_Max
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_Max
->
data
.
dptr
=
nullptr
;
if
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_Max
->
data
.
shape
=
{
max_tokens
,
num_attn_heads
,
1
};
}
else
{
output_Max
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen
,
1
};
}
output_Max
->
data
.
dtype
=
DType
::
kFloat32
;
Tensor
*
output_Sum_Exp
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_Sum_Exp
->
data
.
dptr
=
nullptr
;
if
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_Sum_Exp
->
data
.
shape
=
{
max_tokens
,
num_attn_heads
,
1
};
}
else
{
output_Sum_Exp
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen
,
1
};
}
output_Sum_Exp
->
data
.
dtype
=
DType
::
kFloat32
;
}
else
{
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_S
->
data
.
dptr
=
nullptr
;
output_S
->
data
.
dptr
=
nullptr
;
if
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
if
(
qkv_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
...
@@ -1059,6 +1117,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
...
@@ -1059,6 +1117,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen
,
1
};
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen
,
1
};
}
}
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
}
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
shape
=
{
2
};
...
@@ -1080,8 +1140,15 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
...
@@ -1080,8 +1140,15 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
Aux_CTX_Tensors
->
size
=
i
;
Aux_CTX_Tensors
->
size
=
i
;
}
else
if
(
Aux_CTX_Tensors
->
size
>=
2
)
{
}
else
if
(
Aux_CTX_Tensors
->
size
>=
2
)
{
if
(
return_max_logit
)
{
Tensor
*
output_Max
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
devPtrS1
=
output_Max
->
data
.
dptr
;
Tensor
*
output_Sum_Exp
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
devPtrS2
=
output_Sum_Exp
->
data
.
dptr
;
}
else
{
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
devPtrS
=
output_S
->
data
.
dptr
;
devPtrS1
=
output_S
->
data
.
dptr
;
}
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
...
@@ -1105,11 +1172,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
...
@@ -1105,11 +1172,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
fused_attn_arbitrary_seqlen_fwd_impl
(
fused_attn_arbitrary_seqlen_fwd_impl
(
batch
,
num_attn_heads
,
num_attn_heads
,
max_seqlen
,
max_seqlen
,
head_dim
,
head_dim
,
batch
,
num_attn_heads
,
num_attn_heads
,
max_seqlen
,
max_seqlen
,
head_dim
,
head_dim
,
max_batch_size
,
max_tokens
,
max_tokens
,
0
,
0
,
0
,
0
,
0
,
0
,
bias_b
,
bias_h
,
is_training
,
max_batch_size
,
max_tokens
,
max_tokens
,
0
,
0
,
0
,
0
,
0
,
0
,
bias_b
,
bias_h
,
is_training
,
attn_scale
,
p_dropout
,
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
window_size_left
,
return_max_logit
,
attn_scale
,
p_dropout
,
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
window_size_right
,
devPtrQ
,
devPtrK
,
devPtrV
,
devPtrBias
,
devPtrSoftmaxOffset
,
devPtrS
,
window_size_left
,
window_size_right
,
devPtrQ
,
devPtrK
,
devPtrV
,
devPtrBias
,
devPtrO
,
devPtrDropoutSeed
,
devPtrDropoutOffset
,
devPtrCuSeqlens
,
devPtrCuSeqlens
,
nullptr
,
devPtr
SoftmaxOffset
,
devPtrS1
,
devPtrS2
,
devPtr
O
,
devPtrDropoutSeed
,
devPtrDropoutOffset
,
nullptr
,
devPtrSeqOffsets
,
devPtrSeqOffsets
,
get_cudnn_fe_dtype
(
QKV_type
),
devPtrCuSeqlens
,
devPtrCuSeqlens
,
nullptr
,
nullptr
,
devPtrSeqOffsets
,
devPtrSeqOffsets
,
workspace
->
data
.
dptr
,
&
workspace_size
,
stream
,
handle
);
get_cudnn_fe_dtype
(
QKV_type
),
workspace
->
data
.
dptr
,
&
workspace_size
,
stream
,
handle
);
if
(
workspace_size
>
0
)
{
if
(
workspace_size
>
0
)
{
if
(
workspace
->
data
.
dptr
==
nullptr
)
{
if
(
workspace
->
data
.
dptr
==
nullptr
)
{
...
@@ -1221,14 +1288,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1221,14 +1288,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim
,
size_t
num_tokens_q
,
size_t
num_tokens_kv
,
size_t
max_seqlen_kv
,
size_t
head_dim
,
size_t
num_tokens_q
,
size_t
num_tokens_kv
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
page_size_v
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
page_size_v
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
float
attn_scale
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
bool
return_max_logit
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
const
Tensor
*
input_Q
,
const
Tensor
*
input_KV
,
const
Tensor
*
input_Bias
,
int64_t
window_size_right
,
const
Tensor
*
input_Q
,
const
Tensor
*
input_KV
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
input_Bias
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
cu_seqlens_q_padded
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
const
Tensor
*
cu_seqlens_q_padded
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
)
{
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
)
{
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
const
auto
QKV_type
=
input_Q
->
data
.
dtype
;
const
auto
QKV_type
=
input_Q
->
data
.
dtype
;
...
@@ -1260,7 +1328,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1260,7 +1328,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
}
}
void
*
devPtrO
=
output_O
->
data
.
dptr
;
void
*
devPtrO
=
output_O
->
data
.
dptr
;
void
*
devPtrS
=
nullptr
;
void
*
devPtrS1
=
nullptr
;
void
*
devPtrS2
=
nullptr
;
void
*
devPtrCuSeqlensQ
=
cu_seqlens_q
->
data
.
dptr
;
void
*
devPtrCuSeqlensQ
=
cu_seqlens_q
->
data
.
dptr
;
void
*
devPtrCuSeqlensKV
=
cu_seqlens_kv
->
data
.
dptr
;
void
*
devPtrCuSeqlensKV
=
cu_seqlens_kv
->
data
.
dptr
;
...
@@ -1285,6 +1354,24 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1285,6 +1354,24 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t
i
=
0
;
size_t
i
=
0
;
if
(
Aux_CTX_Tensors
->
size
==
0
)
{
if
(
Aux_CTX_Tensors
->
size
==
0
)
{
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
if
(
return_max_logit
)
{
Tensor
*
output_Max
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_Max
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_Max
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
}
else
{
output_Max
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
output_Max
->
data
.
dtype
=
DType
::
kFloat32
;
Tensor
*
output_Sum_Exp
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_Sum_Exp
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_Sum_Exp
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
}
else
{
output_Sum_Exp
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
output_Sum_Exp
->
data
.
dtype
=
DType
::
kFloat32
;
}
else
{
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_S
->
data
.
dptr
=
nullptr
;
output_S
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
...
@@ -1293,6 +1380,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1293,6 +1380,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
}
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
}
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
shape
=
{
2
};
...
@@ -1314,8 +1403,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1314,8 +1403,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
Aux_CTX_Tensors
->
size
=
i
;
Aux_CTX_Tensors
->
size
=
i
;
}
else
if
(
Aux_CTX_Tensors
->
size
>=
2
)
{
}
else
if
(
Aux_CTX_Tensors
->
size
>=
2
)
{
if
(
return_max_logit
)
{
Tensor
*
output_Max
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
devPtrS1
=
output_Max
->
data
.
dptr
;
Tensor
*
output_Sum_Exp
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
devPtrS2
=
output_Sum_Exp
->
data
.
dptr
;
}
else
{
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
devPtrS
=
output_S
->
data
.
dptr
;
devPtrS1
=
output_S
->
data
.
dptr
;
}
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
...
@@ -1340,11 +1436,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -1340,11 +1436,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
batch
,
num_attn_heads
,
num_gqa_groups
,
max_seqlen_q
,
max_seqlen_kv
,
head_dim
,
head_dim
,
batch
,
num_attn_heads
,
num_gqa_groups
,
max_seqlen_q
,
max_seqlen_kv
,
head_dim
,
head_dim
,
max_batch_size
,
max_tokens_q
,
max_tokens_kv
,
num_pages_k
,
num_pages_v
,
page_size_k
,
max_batch_size
,
max_tokens_q
,
max_tokens_kv
,
num_pages_k
,
num_pages_v
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
bias_b
,
bias_h
,
is_training
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
bias_b
,
bias_h
,
is_training
,
attn_scale
,
p_dropout
,
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
window_size_left
,
return_max_logit
,
attn_scale
,
p_dropout
,
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
window_size_right
,
devPtrQ
,
devPtrK
,
devPtrV
,
devPtrBias
,
devPtrSoftmaxOffset
,
devPtrS
,
window_size_left
,
window_size_right
,
devPtrQ
,
devPtrK
,
devPtrV
,
devPtrBias
,
devPtrO
,
devPtrDropoutSeed
,
devPtrDropoutOffset
,
devPtrCuSeqlensQ
,
devPtrCuSeqlensKV
,
devPtrSoftmaxOffset
,
devPtrS1
,
devPtrS2
,
devPtrO
,
devPtrDropoutSeed
,
devPtrDropoutOffset
,
devPtrPageTableK
,
devPtrPageTableV
,
devPtrSeqOffsetsQ
,
devPtrSeqOffsetsKV
,
devPtrCuSeqlensQ
,
devPtrCuSeqlensKV
,
devPtrPageTableK
,
devPtrPageTableV
,
devPtrSeqOffsetsQ
,
get_cudnn_fe_dtype
(
QKV_type
),
workspace
->
data
.
dptr
,
&
workspace_size
,
stream
,
handle
);
devPtrSeqOffsetsKV
,
get_cudnn_fe_dtype
(
QKV_type
),
workspace
->
data
.
dptr
,
&
workspace_size
,
stream
,
handle
);
if
(
workspace_size
>
0
)
{
if
(
workspace_size
>
0
)
{
if
(
workspace
->
data
.
dptr
==
nullptr
)
{
if
(
workspace
->
data
.
dptr
==
nullptr
)
{
...
@@ -1471,14 +1568,14 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -1471,14 +1568,14 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
size_t
num_tokens_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
size_t
num_tokens_q
,
size_t
num_tokens_kv
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
num_tokens_kv
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
page_size_v
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
size_t
page_size_v
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
bool
return_max_logit
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_
right
,
const
Tensor
*
input_Q
,
const
Tensor
*
input_K
,
const
Tensor
*
input_
V
,
int64_t
window_size_
left
,
int64_t
window_size_right
,
const
Tensor
*
input_
Q
,
const
Tensor
*
input_
Bias
,
const
Tensor
*
input_
SoftmaxOffset
,
Tensor
*
out
put_
O
,
const
Tensor
*
input_
K
,
const
Tensor
*
input_
V
,
const
Tensor
*
in
put_
Bias
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens_q
_padded
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
cu_seqlens_q
_padded
,
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
const
Tensor
*
rng_state
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
)
{
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
)
{
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
const
auto
QKV_type
=
input_Q
->
data
.
dtype
;
const
auto
QKV_type
=
input_Q
->
data
.
dtype
;
...
@@ -1488,7 +1585,8 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -1488,7 +1585,8 @@ void fused_attn_arbitrary_seqlen_fwd(
void
*
devPtrK
=
input_K
->
data
.
dptr
;
void
*
devPtrK
=
input_K
->
data
.
dptr
;
void
*
devPtrV
=
input_V
->
data
.
dptr
;
void
*
devPtrV
=
input_V
->
data
.
dptr
;
void
*
devPtrO
=
output_O
->
data
.
dptr
;
void
*
devPtrO
=
output_O
->
data
.
dptr
;
void
*
devPtrS
=
nullptr
;
void
*
devPtrS1
=
nullptr
;
void
*
devPtrS2
=
nullptr
;
void
*
devPtrBias
=
nullptr
;
void
*
devPtrBias
=
nullptr
;
size_t
bias_b
=
0
;
size_t
bias_b
=
0
;
size_t
bias_h
=
0
;
size_t
bias_h
=
0
;
...
@@ -1525,6 +1623,24 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -1525,6 +1623,24 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t
i
=
0
;
size_t
i
=
0
;
if
(
Aux_CTX_Tensors
->
size
==
0
)
{
if
(
Aux_CTX_Tensors
->
size
==
0
)
{
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
const
auto
cudnn_runtime_version
=
cudnnGetVersion
();
if
(
return_max_logit
)
{
Tensor
*
output_Max
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_Max
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_Max
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
}
else
{
output_Max
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
output_Max
->
data
.
dtype
=
DType
::
kFloat32
;
Tensor
*
output_Sum_Exp
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_Sum_Exp
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
output_Sum_Exp
->
data
.
shape
=
{
max_tokens_q
,
num_attn_heads
,
1
};
}
else
{
output_Sum_Exp
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
output_Sum_Exp
->
data
.
dtype
=
DType
::
kFloat32
;
}
else
{
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_S
->
data
.
dptr
=
nullptr
;
output_S
->
data
.
dptr
=
nullptr
;
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
if
(
q_format
==
NVTE_QKV_Format
::
NVTE_THD
&&
cudnn_runtime_version
>=
90600
)
{
...
@@ -1533,6 +1649,8 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -1533,6 +1649,8 @@ void fused_attn_arbitrary_seqlen_fwd(
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
output_S
->
data
.
shape
=
{
batch
,
num_attn_heads
,
max_seqlen_q
,
1
};
}
}
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
output_S
->
data
.
dtype
=
DType
::
kFloat32
;
}
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
dptr
=
nullptr
;
output_rng_state
->
data
.
shape
=
{
2
};
output_rng_state
->
data
.
shape
=
{
2
};
...
@@ -1554,8 +1672,15 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -1554,8 +1672,15 @@ void fused_attn_arbitrary_seqlen_fwd(
Aux_CTX_Tensors
->
size
=
i
;
Aux_CTX_Tensors
->
size
=
i
;
}
else
if
(
Aux_CTX_Tensors
->
size
>=
2
)
{
}
else
if
(
Aux_CTX_Tensors
->
size
>=
2
)
{
if
(
return_max_logit
)
{
Tensor
*
output_Max
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
devPtrS1
=
output_Max
->
data
.
dptr
;
Tensor
*
output_Sum_Exp
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
devPtrS2
=
output_Sum_Exp
->
data
.
dptr
;
}
else
{
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_S
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
devPtrS
=
output_S
->
data
.
dptr
;
devPtrS1
=
output_S
->
data
.
dptr
;
}
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
Tensor
*
output_rng_state
=
convertNVTETensorCheck
(
Aux_CTX_Tensors
->
tensors
[
i
++
]);
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
output_rng_state
->
data
.
dptr
=
rng_state
->
data
.
dptr
;
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
if
((
bias_type
!=
NVTE_NO_BIAS
)
&&
(
bias_type
!=
NVTE_ALIBI
))
{
...
@@ -1580,11 +1705,12 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -1580,11 +1705,12 @@ void fused_attn_arbitrary_seqlen_fwd(
batch
,
num_attn_heads
,
num_gqa_groups
,
max_seqlen_q
,
max_seqlen_kv
,
head_dim_qk
,
head_dim_v
,
batch
,
num_attn_heads
,
num_gqa_groups
,
max_seqlen_q
,
max_seqlen_kv
,
head_dim_qk
,
head_dim_v
,
max_batch_size
,
max_tokens_q
,
max_tokens_kv
,
num_pages_k
,
num_pages_v
,
page_size_k
,
max_batch_size
,
max_tokens_q
,
max_tokens_kv
,
num_pages_k
,
num_pages_v
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
bias_b
,
bias_h
,
is_training
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
bias_b
,
bias_h
,
is_training
,
attn_scale
,
p_dropout
,
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
window_size_left
,
return_max_logit
,
attn_scale
,
p_dropout
,
qkv_layout
,
bias_type
,
mask_type
,
softmax_type
,
window_size_right
,
devPtrQ
,
devPtrK
,
devPtrV
,
devPtrBias
,
devPtrSoftmaxOffset
,
devPtrS
,
window_size_left
,
window_size_right
,
devPtrQ
,
devPtrK
,
devPtrV
,
devPtrBias
,
devPtrO
,
devPtrDropoutSeed
,
devPtrDropoutOffset
,
devPtrCuSeqlensQ
,
devPtrCuSeqlensKV
,
devPtrSoftmaxOffset
,
devPtrS1
,
devPtrS2
,
devPtrO
,
devPtrDropoutSeed
,
devPtrDropoutOffset
,
devPtrPageTableK
,
devPtrPageTableV
,
devPtrSeqOffsetsQ
,
devPtrSeqOffsetsKV
,
devPtrCuSeqlensQ
,
devPtrCuSeqlensKV
,
devPtrPageTableK
,
devPtrPageTableV
,
devPtrSeqOffsetsQ
,
get_cudnn_fe_dtype
(
QKV_type
),
workspace
->
data
.
dptr
,
&
workspace_size
,
stream
,
handle
);
devPtrSeqOffsetsKV
,
get_cudnn_fe_dtype
(
QKV_type
),
workspace
->
data
.
dptr
,
&
workspace_size
,
stream
,
handle
);
if
(
workspace_size
>
0
)
{
if
(
workspace_size
>
0
)
{
if
(
workspace
->
data
.
dptr
==
nullptr
)
{
if
(
workspace
->
data
.
dptr
==
nullptr
)
{
...
...
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
View file @
0a5016b1
...
@@ -20,12 +20,13 @@ namespace transformer_engine {
...
@@ -20,12 +20,13 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
#if (CUDNN_VERSION >= 8900)
void
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
void
fused_attn_arbitrary_seqlen_fwd_qkvpacked
(
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
bool
is_training
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
p_dropout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
const
Tensor
*
input_QKV
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
const
Tensor
*
input_Bias
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
const
Tensor
*
input_QKV
,
const
Tensor
*
input_Bias
,
const
Tensor
*
input_SoftmaxOffset
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens
,
const
Tensor
*
cu_seqlens_padded
,
Tensor
*
output_O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
const
Tensor
*
cu_seqlens_padded
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
void
fused_attn_arbitrary_seqlen_bwd_qkvpacked
(
void
fused_attn_arbitrary_seqlen_bwd_qkvpacked
(
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
max_seqlen
,
size_t
head_dim
,
size_t
num_tokens
,
...
@@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
...
@@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim
,
size_t
num_tokens_q
,
size_t
num_tokens_kv
,
size_t
max_seqlen_kv
,
size_t
head_dim
,
size_t
num_tokens_q
,
size_t
num_tokens_kv
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
page_size_v
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
page_size_v
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
float
attn_scale
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
bool
return_max_logit
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
const
Tensor
*
input_Q
,
const
Tensor
*
input_KV
,
const
Tensor
*
input_Bias
,
int64_t
window_size_right
,
const
Tensor
*
input_Q
,
const
Tensor
*
input_KV
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
input_Bias
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
cu_seqlens_q_padded
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
const
Tensor
*
cu_seqlens_q_padded
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
void
fused_attn_arbitrary_seqlen_bwd_kvpacked
(
void
fused_attn_arbitrary_seqlen_bwd_kvpacked
(
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
...
@@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd(
...
@@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
size_t
num_tokens_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
size_t
num_tokens_q
,
size_t
num_tokens_kv
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
num_tokens_kv
,
size_t
num_pages_k
,
size_t
num_pages_v
,
size_t
page_size_k
,
size_t
page_size_v
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
size_t
page_size_v
,
size_t
max_pages_per_seq_k
,
size_t
max_pages_per_seq_v
,
bool
is_training
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
bool
return_max_logit
,
float
attn_scale
,
float
p_dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_
right
,
const
Tensor
*
input_Q
,
const
Tensor
*
input_K
,
const
Tensor
*
input_
V
,
int64_t
window_size_
left
,
int64_t
window_size_right
,
const
Tensor
*
input_
Q
,
const
Tensor
*
input_
Bias
,
const
Tensor
*
input_
SoftmaxOffset
,
Tensor
*
out
put_
O
,
const
Tensor
*
input_
K
,
const
Tensor
*
input_
V
,
const
Tensor
*
in
put_
Bias
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
input_SoftmaxOffset
,
Tensor
*
output_O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
Tensor
*
cu_seqlens_q
_padded
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
cu_seqlens_q
,
const
Tensor
*
cu_seqlens_kv
,
const
Tensor
*
cu_seqlens_q
_padded
,
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
const
Tensor
*
rng_state
,
const
Tensor
*
cu_seqlens_kv_padded
,
const
Tensor
*
page_table_k
,
const
Tensor
*
page_table_v
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
const
Tensor
*
rng_state
,
Tensor
*
workspace
,
cudaStream_t
stream
,
cudnnHandle_t
handle
);
void
fused_attn_arbitrary_seqlen_bwd
(
void
fused_attn_arbitrary_seqlen_bwd
(
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
batch
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
...
...
transformer_engine/common/fused_attn/fused_attn_fp8.cu
View file @
0a5016b1
...
@@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1(
...
@@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1(
qkv_tensor_type
,
qkv_tensor_type
,
o_tensor_type
,
o_tensor_type
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
,
cudnn_frontend
::
DataType_t
::
NOT_SET
};
cudnn_frontend
::
DataType_t
::
NOT_SET
,
false
};
namespace
fe
=
cudnn_frontend
;
namespace
fe
=
cudnn_frontend
;
using
graph_and_tensors
=
using
graph_and_tensors
=
...
@@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1(
...
@@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1(
qkv_tensor_type
,
qkv_tensor_type
,
o_tensor_type
,
o_tensor_type
,
do_tensor_type
,
do_tensor_type
,
dqkv_tensor_type
};
dqkv_tensor_type
,
false
};
namespace
fe
=
cudnn_frontend
;
namespace
fe
=
cudnn_frontend
;
using
graph_and_tensors
=
using
graph_and_tensors
=
...
...
transformer_engine/common/fused_attn/utils.h
View file @
0a5016b1
...
@@ -115,20 +115,21 @@ struct FADescriptor_v1 {
...
@@ -115,20 +115,21 @@ struct FADescriptor_v1 {
cudnn_frontend
::
DataType_t
o_tensor_type
;
cudnn_frontend
::
DataType_t
o_tensor_type
;
cudnn_frontend
::
DataType_t
do_tensor_type
;
cudnn_frontend
::
DataType_t
do_tensor_type
;
cudnn_frontend
::
DataType_t
dqkv_tensor_type
;
cudnn_frontend
::
DataType_t
dqkv_tensor_type
;
bool
generate_max_sum_exp
;
bool
operator
<
(
const
FADescriptor_v1
&
rhs
)
const
{
bool
operator
<
(
const
FADescriptor_v1
&
rhs
)
const
{
return
std
::
tie
(
b
,
h
,
hg
,
s_q
,
s_kv
,
d_qk
,
d_v
,
num_pages_k
,
num_pages_v
,
page_size_k
,
return
std
::
tie
(
b
,
h
,
hg
,
s_q
,
s_kv
,
d_qk
,
d_v
,
num_pages_k
,
num_pages_v
,
page_size_k
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
bias_b
,
bias_h
,
page_size_v
,
max_pages_per_seq_k
,
max_pages_per_seq_v
,
bias_b
,
bias_h
,
attnScale
,
isTraining
,
dropoutProbability
,
layout
,
mask_type
,
softmax_type
,
attnScale
,
isTraining
,
dropoutProbability
,
layout
,
mask_type
,
softmax_type
,
window_size_left
,
window_size_right
,
deterministic
,
bias_type
,
qkv_tensor_type
,
window_size_left
,
window_size_right
,
deterministic
,
bias_type
,
qkv_tensor_type
,
o_tensor_type
,
do_tensor_type
,
dqkv_tensor_type
)
<
o_tensor_type
,
do_tensor_type
,
dqkv_tensor_type
,
generate_max_sum_exp
)
<
std
::
tie
(
rhs
.
b
,
rhs
.
h
,
rhs
.
hg
,
rhs
.
s_q
,
rhs
.
s_kv
,
rhs
.
d_qk
,
rhs
.
d_v
,
rhs
.
num_pages_k
,
std
::
tie
(
rhs
.
b
,
rhs
.
h
,
rhs
.
hg
,
rhs
.
s_q
,
rhs
.
s_kv
,
rhs
.
d_qk
,
rhs
.
d_v
,
rhs
.
num_pages_k
,
rhs
.
num_pages_v
,
rhs
.
page_size_k
,
rhs
.
page_size_v
,
rhs
.
max_pages_per_seq_k
,
rhs
.
num_pages_v
,
rhs
.
page_size_k
,
rhs
.
page_size_v
,
rhs
.
max_pages_per_seq_k
,
rhs
.
max_pages_per_seq_v
,
rhs
.
bias_b
,
rhs
.
bias_h
,
rhs
.
attnScale
,
rhs
.
isTraining
,
rhs
.
max_pages_per_seq_v
,
rhs
.
bias_b
,
rhs
.
bias_h
,
rhs
.
attnScale
,
rhs
.
isTraining
,
rhs
.
dropoutProbability
,
rhs
.
layout
,
rhs
.
mask_type
,
rhs
.
softmax_type
,
rhs
.
dropoutProbability
,
rhs
.
layout
,
rhs
.
mask_type
,
rhs
.
softmax_type
,
rhs
.
window_size_left
,
rhs
.
window_size_right
,
rhs
.
deterministic
,
rhs
.
bias_type
,
rhs
.
window_size_left
,
rhs
.
window_size_right
,
rhs
.
deterministic
,
rhs
.
bias_type
,
rhs
.
qkv_tensor_type
,
rhs
.
o_tensor_type
,
rhs
.
do_tensor_type
,
rhs
.
qkv_tensor_type
,
rhs
.
o_tensor_type
,
rhs
.
do_tensor_type
,
rhs
.
dqkv_tensor_type
);
rhs
.
dqkv_tensor_type
,
rhs
.
generate_max_sum_exp
);
}
}
};
};
...
...
transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu
View file @
0a5016b1
...
@@ -97,7 +97,8 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
...
@@ -97,7 +97,8 @@ cutlass::Array<cutlass::float_e2m1_t, 8>
StochasticNumericConverterBase
(
cutlass
::
Array
<
float
,
8
>
const
&
input
,
cutlass
::
Array
<
uint32_t
,
2
>
const
&
rbits
)
{
StochasticNumericConverterBase
(
cutlass
::
Array
<
float
,
8
>
const
&
input
,
cutlass
::
Array
<
uint32_t
,
2
>
const
&
rbits
)
{
using
result_type
=
cutlass
::
Array
<
cutlass
::
float_e2m1_t
,
8
>
;
using
result_type
=
cutlass
::
Array
<
cutlass
::
float_e2m1_t
,
8
>
;
result_type
output
;
result_type
output
;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_rs
=
ARCH_HAS_STOCHASTIC_ROUNDING
;
if
constexpr
(
has_rs
)
{
auto
output_ptr
=
reinterpret_cast
<
uint16_t
*>
(
&
output
);
auto
output_ptr
=
reinterpret_cast
<
uint16_t
*>
(
&
output
);
asm
volatile
(
\
asm
volatile
(
\
"{
\n
"
\
"{
\n
"
\
...
@@ -109,10 +110,10 @@ StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::A
...
@@ -109,10 +110,10 @@ StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::A
:
"f"
(
input
[
0
]),
"f"
(
input
[
1
]),
"f"
(
input
[
2
]),
"f"
(
input
[
3
]),
:
"f"
(
input
[
0
]),
"f"
(
input
[
1
]),
"f"
(
input
[
2
]),
"f"
(
input
[
3
]),
"f"
(
input
[
4
]),
"f"
(
input
[
5
]),
"f"
(
input
[
6
]),
"f"
(
input
[
7
]),
"f"
(
input
[
4
]),
"f"
(
input
[
5
]),
"f"
(
input
[
6
]),
"f"
(
input
[
7
]),
"r"
(
rbits
[
0
]),
"r"
(
rbits
[
1
]));
"r"
(
rbits
[
0
]),
"r"
(
rbits
[
1
]));
#
else
}
else
{
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
return
output
;
return
output
;
}
}
...
...
transformer_engine/common/include/transformer_engine/fused_attn.h
View file @
0a5016b1
...
@@ -206,13 +206,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
...
@@ -206,13 +206,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] head_dim_v The head dimension of V.
* \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
*/
*/
NVTE_Fused_Attn_Backend
nvte_get_fused_attn_backend
(
NVTE_Fused_Attn_Backend
nvte_get_fused_attn_backend
(
bool
is_training
,
NVTEDType
q_dtype
,
NVTEDType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
bool
is_training
,
NVTEDType
q_dtype
,
NVTEDType
kv_dtype
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
float
dropout
,
size_t
num_attn_heads
,
size_t
num_gqa_groups
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
size_t
max_seqlen_kv
,
size_t
head_dim_qk
,
size_t
head_dim_v
,
int64_t
window_size_left
,
int64_t
window_size_right
);
int64_t
window_size_right
,
bool
return_max_logit
);
/*! \brief Compute dot product attention with packed QKV input.
/*! \brief Compute dot product attention with packed QKV input.
*
*
...
@@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] max_seqlen Max sequence length used for computing,
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] qkv_layout QKV tensor's layout.
...
@@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
...
@@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] workspace Workspace tensor.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
*/
void
nvte_fused_attn_fwd_qkvpacked
(
void
nvte_fused_attn_fwd_qkvpacked
(
const
NVTETensor
QKV
,
const
NVTETensor
Bias
,
const
NVTETensor
QKV
,
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
size_t
max_seqlen
,
const
NVTETensor
cu_seqlens_padded
,
const
NVTETensor
rng_state
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
size_t
max_seqlen
,
bool
is_training
,
bool
return_max_logit
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
*
...
@@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
...
@@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] qkv_layout QKV tensor's layout.
...
@@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked(
...
@@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked(
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
NVTETensor
workspace
,
cudaStream_t
stream
);
...
@@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked(
...
@@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] qkv_layout QKV tensors' layout.
...
@@ -531,17 +538,15 @@ void nvte_fused_attn_bwd_kvpacked(
...
@@ -531,17 +538,15 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] workspace Workspace tensor.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
*/
void
nvte_fused_attn_fwd
(
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
void
nvte_fused_attn_fwd
(
const
NVTETensor
Bias
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
const
NVTETensor
Q
,
const
NVTETensor
K
,
const
NVTETensor
V
,
const
NVTETensor
Bias
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
SoftmaxOffset
,
NVTETensor
S
,
NVTETensor
O
,
NVTETensorPack
*
Aux_CTX_Tensors
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q
,
const
NVTETensor
cu_seqlens_kv
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_q_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
cu_seqlens_kv_padded
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_k
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
const
NVTETensor
page_table_v
,
const
NVTETensor
rng_state
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
bool
return_max_logit
,
size_t
max_seqlen_q
,
size_t
max_seqlen_kv
,
bool
is_training
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Bias_Type
bias_type
,
float
attn_scale
,
float
dropout
,
NVTE_QKV_Layout
qkv_layout
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
attn_mask_type
,
NVTE_Softmax_Type
softmax_type
,
int64_t
window_size_left
,
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
int64_t
window_size_right
,
NVTETensor
workspace
,
cudaStream_t
stream
);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
...
...
transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu
View file @
0a5016b1
...
@@ -264,7 +264,8 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s
...
@@ -264,7 +264,8 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s
__device__
__forceinline__
__nv_fp4x4_e2m1
cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
__device__
__forceinline__
__nv_fp4x4_e2m1
cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
const
float2
in01
,
const
float2
in23
,
const
uint32_t
rbits
)
{
const
float2
in01
,
const
float2
in23
,
const
uint32_t
rbits
)
{
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_rs
=
ARCH_HAS_STOCHASTIC_ROUNDING
;
if
constexpr
(
has_rs
)
{
uint16_t
out_4x
;
uint16_t
out_4x
;
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
...
@@ -273,19 +274,20 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro
...
@@ -273,19 +274,20 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_ro
:
"=h"
(
out_4x
)
:
"=h"
(
out_4x
)
:
"f"
(
in01
.
y
),
"f"
(
in01
.
x
),
"f"
(
in23
.
y
),
"f"
(
in23
.
x
),
"r"
(
rbits
));
:
"f"
(
in01
.
y
),
"f"
(
in01
.
x
),
"f"
(
in23
.
y
),
"f"
(
in23
.
x
),
"r"
(
rbits
));
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
out_4x
);
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
out_4x
);
#
else
}
else
{
NVTE_DEVICE_ERROR
(
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"FP4 cvt
.rs
PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"Try recompiling with sm_XXXa instead of sm_XXX."
);
uint16_t
dummy
=
0
;
uint16_t
dummy
=
0
;
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
dummy
);
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
dummy
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
}
}
__device__
__forceinline__
__nv_fp4x4_e2m1
cvt_fp32_to_fp4_4x_with_rn
(
const
float2
in01
,
__device__
__forceinline__
__nv_fp4x4_e2m1
cvt_fp32_to_fp4_4x_with_rn
(
const
float2
in01
,
const
float2
in23
,
const
float2
in23
,
const
uint32_t
rbits
)
{
const
uint32_t
rbits
)
{
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_fp4
=
ARCH_BLACKWELL_FAMILY
;
if
constexpr
(
has_fp4
)
{
// NOTE: rbits unused for rn.
// NOTE: rbits unused for rn.
uint32_t
out_4x
;
// Only need 16 bit. Using 32 bit container for packing.
uint32_t
out_4x
;
// Only need 16 bit. Using 32 bit container for packing.
asm
volatile
(
asm
volatile
(
...
@@ -299,13 +301,13 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const floa
...
@@ -299,13 +301,13 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const floa
:
"=r"
(
out_4x
)
:
"=r"
(
out_4x
)
:
"f"
(
in01
.
y
),
"f"
(
in01
.
x
),
"f"
(
in23
.
y
),
"f"
(
in23
.
x
));
:
"f"
(
in01
.
y
),
"f"
(
in01
.
x
),
"f"
(
in23
.
y
),
"f"
(
in23
.
x
));
return
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
out_4x
)[
0
];
return
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
out_4x
)[
0
];
#
else
}
else
{
NVTE_DEVICE_ERROR
(
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"Try recompiling with sm_XXXa instead of sm_XXX."
);
uint16_t
dummy
=
0
;
uint16_t
dummy
=
0
;
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
dummy
);
return
*
reinterpret_cast
<
__nv_fp4x4_e2m1
*>
(
&
dummy
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
}
}
template
<
bool
kApplyStochasticRounding
>
template
<
bool
kApplyStochasticRounding
>
...
...
transformer_engine/common/util/nvfp4_transpose.cuh
View file @
0a5016b1
...
@@ -15,10 +15,9 @@
...
@@ -15,10 +15,9 @@
#include <cudaTypedefs.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#if
CUDA_VERSION > 12080
#if
FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>
#include <cuda_fp4.h>
#endif // CUDA_VERSION > 12080
#endif // FP4_TYPE_SUPPORTED
#include <cfloat>
#include <cfloat>
#include "../common.h"
#include "../common.h"
...
@@ -30,7 +29,7 @@
...
@@ -30,7 +29,7 @@
namespace
transformer_engine
{
namespace
transformer_engine
{
#if
CUDA_VERSION > 12080
#if
FP4_TYPE_SUPPORTED
namespace
nvfp4_transpose
{
namespace
nvfp4_transpose
{
using
RNG
=
decltype
(
curanddx
::
Generator
<
curanddx
::
philox4_32
>
()
+
curanddx
::
PhiloxRounds
<
10
>
()
+
using
RNG
=
decltype
(
curanddx
::
Generator
<
curanddx
::
philox4_32
>
()
+
curanddx
::
PhiloxRounds
<
10
>
()
+
...
@@ -152,12 +151,11 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int
...
@@ -152,12 +151,11 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int
return
rbits
;
return
rbits
;
}
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding
(
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding
(
const
uint64_t
in_4x
,
const
float2
scale
,
const
uint32_t
rbits
)
{
const
uint64_t
in_4x
,
const
float2
scale
,
const
uint32_t
rbits
)
{
uint16_t
out_4x
=
0
;
uint16_t
out_4x
=
0
;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_rs
=
ARCH_HAS_STOCHASTIC_ROUNDING
;
if
constexpr
(
has_rs
)
{
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 v01;
\n\t
"
".reg.b64 v01;
\n\t
"
...
@@ -185,20 +183,21 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_roun
...
@@ -185,20 +183,21 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_roun
"}"
"}"
:
"=h"
(
out_4x
)
:
"=h"
(
out_4x
)
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
#
else
}
else
{
NVTE_DEVICE_ERROR
(
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
}
}
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_rn
(
const
uint64_t
in_4x
,
__device__
__forceinline__
fp4e2m1x4
mul_cvt_bf16_to_fp4_4x_with_rn
(
const
uint64_t
in_4x
,
const
float2
scale
,
const
float2
scale
,
const
uint32_t
rbits
)
{
const
uint32_t
rbits
)
{
// NOTE: rbits unused for rn.
constexpr
bool
is_blackwell
=
ARCH_BLACKWELL_FAMILY
;
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
if
constexpr
(
is_blackwell
)
{
// NOTE: rbits unused for rn.
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 v01;
\n\t
"
".reg.b64 v01;
\n\t
"
...
@@ -230,11 +229,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64
...
@@ -230,11 +229,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64
"}"
"}"
:
"=r"
(
out_4x
)
:
"=r"
(
out_4x
)
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
:
"l"
(
in_4x
),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#
else
}
else
{
NVTE_DEVICE_ERROR
(
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
}
}
...
@@ -252,7 +251,8 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x
...
@@ -252,7 +251,8 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x
__device__
__forceinline__
fp4e2m1x4
mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
__device__
__forceinline__
fp4e2m1x4
mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding
(
const
float2
in01
,
const
float2
in23
,
const
float2
scale
,
const
uint32_t
rbits
)
{
const
float2
in01
,
const
float2
in23
,
const
float2
scale
,
const
uint32_t
rbits
)
{
uint16_t
out_4x
=
0
;
uint16_t
out_4x
=
0
;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
has_rs
=
ARCH_HAS_STOCHASTIC_ROUNDING
;
if
constexpr
(
has_rs
)
{
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 v01;
\n\t
"
".reg.b64 v01;
\n\t
"
...
@@ -275,11 +275,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_roun
...
@@ -275,11 +275,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_roun
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
rbits
));
#
else
}
else
{
NVTE_DEVICE_ERROR
(
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
return
*
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
);
}
}
...
@@ -287,9 +287,10 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
...
@@ -287,9 +287,10 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
const
float2
in23
,
const
float2
in23
,
const
float2
scale
,
const
float2
scale
,
const
uint32_t
rbits
)
{
const
uint32_t
rbits
)
{
// NOTE: rbits unused for rn.
constexpr
bool
is_blackwell
=
ARCH_BLACKWELL_FAMILY
;
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
uint32_t
out_4x
=
0
;
// Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
if
constexpr
(
is_blackwell
)
{
// NOTE: rbits unused for rn.
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 v01;
\n\t
"
".reg.b64 v01;
\n\t
"
...
@@ -316,11 +317,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
...
@@ -316,11 +317,11 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in01
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in23
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#
else
}
else
{
NVTE_DEVICE_ERROR
(
NVTE_DEVICE_ERROR
(
"FP4 cvt PTX instructions are architecture-specific. "
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX."
);
"Try recompiling with sm_XXXa instead of sm_XXX."
);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
return
reinterpret_cast
<
fp4e2m1x4
*>
(
&
out_4x
)[
0
];
}
}
...
@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
...
@@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
}
}
}
}
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
typename
IType
,
bool
USE_STOCHASTIC_ROUNDING
,
bool
RETURN_TRANSPOSE
>
typename
IType
,
bool
USE_STOCHASTIC_ROUNDING
,
bool
RETURN_TRANSPOSE
>
__global__
void
__launch_bounds__
(
THREADS_NUM
)
__global__
void
__launch_bounds__
(
THREADS_NUM
)
...
@@ -1380,18 +1379,13 @@ __global__ void __launch_bounds__(THREADS_NUM)
...
@@ -1380,18 +1379,13 @@ __global__ void __launch_bounds__(THREADS_NUM)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
}
// namespace nvfp4_transpose
}
// namespace nvfp4_transpose
#endif // CUDA_VERSION > 12080
#endif // FP4_TYPE_SUPPORTED
// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
template
<
bool
COMPUTE_ACTIVATIONS
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
),
bool
use_2d_quantization
>
bool
use_2d_quantization
>
void
nvfp4_quantize_transpose
(
const
Tensor
&
input
,
const
Tensor
*
noop
,
Tensor
*
output
,
void
nvfp4_quantize_transpose
(
const
Tensor
&
input
,
const
Tensor
*
noop
,
Tensor
*
output
,
const
QuantizationConfig
*
quant_config
,
cudaStream_t
stream
)
{
const
QuantizationConfig
*
quant_config
,
cudaStream_t
stream
)
{
#if
CUDA_VERSION > 12080
#if
FP4_TYPE_SUPPORTED
bool
use_stochastic_rounding
=
quant_config
?
quant_config
->
stochastic_rounding
:
false
;
bool
use_stochastic_rounding
=
quant_config
?
quant_config
->
stochastic_rounding
:
false
;
// If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
// If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
...
@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o
...
@@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o
}););
}););
#else
#else
NVTE_ERROR
(
"FP4 support requires CUDA 12.8+, but compile-time CUDA version is "
,
CUDA_VERSION
);
NVTE_ERROR
(
"FP4 support requires CUDA 12.8+, but compile-time CUDA version is "
,
CUDA_VERSION
);
#endif //
CUDA_VERSION > 12080
#endif //
FP4_TYPE_SUPPORTED
}
}
}
// namespace transformer_engine
}
// namespace transformer_engine
...
...
transformer_engine/common/util/ptx.cuh
View file @
0a5016b1
...
@@ -18,44 +18,165 @@
...
@@ -18,44 +18,165 @@
#include <cuda_fp4.h>
#include <cuda_fp4.h>
#endif // CUDA_VERSION >= 12080
#endif // CUDA_VERSION >= 12080
#include "common/utils.cuh"
namespace
transformer_engine
{
namespace
transformer_engine
{
namespace
ptx
{
namespace
ptx
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template
<
int
N
>
struct
ArchSpecific
{
constexpr
static
int
id
=
N
*
10
;
template
<
int
CurrentArch
,
int
ArchSpecific
,
int
FamilySpecific
>
constexpr
static
bool
compatible
()
{
if
constexpr
(
CurrentArch
==
id
)
{
static_assert
(
ArchSpecific
==
CurrentArch
,
"Compiled for the generic architecture, while utilizing arch-specific "
"features. Please compile for smXXXa architecture instead of smXXX "
"architecture."
);
return
true
;
}
else
{
return
false
;
}
}
};
template
<
int
N
>
struct
FamilySpecific
{
constexpr
static
int
id
=
N
*
10
;
template
<
int
CurrentArch
,
int
ArchSpecific
,
int
FamilySpecific
>
constexpr
static
bool
compatible
()
{
if
constexpr
((
CurrentArch
/
100
)
==
(
id
/
100
))
{
static_assert
(
FamilySpecific
==
CurrentArch
,
"Compiled for the generic architecture, while utilizing family-specific "
"features. Please compile for smXXXf architecture instead of smXXX "
"architecture."
);
return
true
;
}
else
{
return
false
;
}
}
};
template
<
int
Arch
,
int
ArchSpecific
,
int
FamilySpecific
,
class
T
,
class
...
U
>
constexpr
bool
is_supported_arch
()
{
if
constexpr
(
T
::
template
compatible
<
Arch
,
ArchSpecific
,
FamilySpecific
>())
{
return
true
;
}
else
if
constexpr
(
sizeof
...(
U
)
!=
0
)
{
return
is_supported_arch
<
Arch
,
ArchSpecific
,
FamilySpecific
,
U
...
>
();
}
else
{
return
false
;
}
}
#if CUDA_VERSION < 12090
#if __CUDA_ARCH_HAS_FEATURE__(SM90_ALL)
#define __CUDA_ARCH_SPECIFIC__ 900
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 900
#endif
#if __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
#define __CUDA_ARCH_SPECIFIC__ 1000
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1000
#endif
#if __CUDA_ARCH_HAS_FEATURE__(SM101_ALL)
#define __CUDA_ARCH_SPECIFIC__ 1010
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1010
#endif
#if __CUDA_ARCH_HAS_FEATURE__(SM120_ALL)
#define __CUDA_ARCH_SPECIFIC__ 1200
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1200
#endif
#endif
#ifdef __CUDA_ARCH__
#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = __CUDA_ARCH__;
#else
#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = 0;
#endif
#ifdef __CUDA_ARCH_SPECIFIC__
#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = __CUDA_ARCH_SPECIFIC__;
#else
#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = 0;
#endif
#ifdef __CUDA_ARCH_FAMILY_SPECIFIC__
#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = __CUDA_ARCH_FAMILY_SPECIFIC__;
#else
#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = 0;
#endif
#define NVTE_CUDA_ARCH_MATCHES(...) \
[&] { \
__NVTE_CURRENT_ARCH__ \
__NVTE_ARCH_SPECIFIC__ \
__NVTE_ARCH_FAMILY_SPECIFIC__ \
return transformer_engine::ptx::is_supported_arch<current_arch, ArchSpecific, FamilySpecific, \
__VA_ARGS__>(); \
}();
#define ARCH_BLACKWELL_FAMILY \
NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>, ptx::FamilySpecific<110>, \
ptx::FamilySpecific<120>)
#define ARCH_HAS_STOCHASTIC_ROUNDING \
NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
__device__
__forceinline__
void
mbarrier_init
(
uint64_t
*
mbar
,
const
uint32_t
count
)
{
__device__
__forceinline__
void
mbarrier_init
(
uint64_t
*
mbar
,
const
uint32_t
count
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
asm
volatile
(
"mbarrier.init.shared.b64 [%0], %1;"
::
"r"
(
mbar_ptr
),
"r"
(
count
)
:
"memory"
);
asm
volatile
(
"mbarrier.init.shared.b64 [%0], %1;"
::
"r"
(
mbar_ptr
),
"r"
(
count
)
:
"memory"
);
#else
NVTE_DEVICE_ERROR
(
"mbarrier_init is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval
__device__
__forceinline__
void
mbarrier_invalid
(
uint64_t
*
mbar
)
{
__device__
__forceinline__
void
mbarrier_invalid
(
uint64_t
*
mbar
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
asm
volatile
(
"mbarrier.inval.shared.b64 [%0];"
::
"r"
(
mbar_ptr
)
:
"memory"
);
asm
volatile
(
"mbarrier.inval.shared.b64 [%0];"
::
"r"
(
mbar_ptr
)
:
"memory"
);
#else
NVTE_DEVICE_ERROR
(
"mbarrier_invalid is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
__device__
__forceinline__
void
mbarrier_arrive
(
uint64_t
*
mbar
)
{
__device__
__forceinline__
void
mbarrier_arrive
(
uint64_t
*
mbar
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
asm
volatile
(
"mbarrier.arrive.shared.b64 _, [%0];"
::
"r"
(
mbar_ptr
)
:
"memory"
);
asm
volatile
(
"mbarrier.arrive.shared.b64 _, [%0];"
::
"r"
(
mbar_ptr
)
:
"memory"
);
#else
NVTE_DEVICE_ERROR
(
"mbarrier_arrive is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
// Arrive on the shared-memory mbarrier and additionally raise its expected
// transaction count by `tx_count` bytes (used to pair with async bulk
// copies that complete the transactions).
__device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar,
                                                          const uint32_t tx_count) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::
                   "r"(mbar_ptr), "r"(tx_count)
               : "memory");
#else
  NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
// Order preceding mbarrier initializations with release semantics at
// cluster scope, per the instruction's mnemonic
// (fence.mbarrier_init.release.cluster).
__device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  asm volatile("fence.mbarrier_init.release.cluster;");
#else
  NVTE_DEVICE_ERROR("fence_mbarrier_init_release_cluster is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// global -> shared::cluster
// global -> shared::cluster
__device__
__forceinline__
void
cp_async_bulk_tensor_1d_global_to_shared
(
__device__
__forceinline__
void
cp_async_bulk_tensor_1d_global_to_shared
(
uint64_t
*
dst_shmem
,
const
uint64_t
*
src_global_ptr
,
const
uint32_t
size
,
uint64_t
*
mbar
)
{
uint64_t
*
dst_shmem
,
const
uint64_t
*
src_global_ptr
,
const
uint32_t
size
,
uint64_t
*
mbar
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
dst_shmem_ptr
=
__cvta_generic_to_shared
(
dst_shmem
);
uint32_t
dst_shmem_ptr
=
__cvta_generic_to_shared
(
dst_shmem
);
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
// triggers async copy, i.e. the thread continues until wait() on mbarrier
// triggers async copy, i.e. the thread continues until wait() on mbarrier
...
@@ -67,6 +188,9 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
...
@@ -67,6 +188,9 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
::
"r"
(
dst_shmem_ptr
),
".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
::
"r"
(
dst_shmem_ptr
),
"l"
(
src_global_ptr
),
"r"
(
size
),
"r"
(
mbar_ptr
)
"l"
(
src_global_ptr
),
"r"
(
size
),
"r"
(
mbar_ptr
)
:
"memory"
);
:
"memory"
);
#else
NVTE_DEVICE_ERROR
(
"cp_async_bulk_tensor_1d_global_to_shared is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
...
@@ -74,6 +198,7 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
...
@@ -74,6 +198,7 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
__device__
__forceinline__
void
cp_async_bulk_tensor_2d_global_to_shared
(
__device__
__forceinline__
void
cp_async_bulk_tensor_2d_global_to_shared
(
uint64_t
*
dst_shmem
,
const
uint64_t
*
tensor_map_ptr
,
const
uint32_t
offset_x
,
uint64_t
*
dst_shmem
,
const
uint64_t
*
tensor_map_ptr
,
const
uint32_t
offset_x
,
const
uint32_t
offset_y
,
uint64_t
*
mbar
)
{
const
uint32_t
offset_y
,
uint64_t
*
mbar
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
dst_shmem_ptr
=
__cvta_generic_to_shared
(
dst_shmem
);
uint32_t
dst_shmem_ptr
=
__cvta_generic_to_shared
(
dst_shmem
);
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
// triggers async copy, i.e. the thread continues until wait() on mbarrier
// triggers async copy, i.e. the thread continues until wait() on mbarrier
...
@@ -85,9 +210,13 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
...
@@ -85,9 +210,13 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];"
::
"r"
(
dst_shmem_ptr
),
".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];"
::
"r"
(
dst_shmem_ptr
),
"l"
(
tensor_map_ptr
),
"r"
(
offset_x
),
"r"
(
offset_y
),
"r"
(
mbar_ptr
)
"l"
(
tensor_map_ptr
),
"r"
(
offset_x
),
"r"
(
offset_y
),
"r"
(
mbar_ptr
)
:
"memory"
);
:
"memory"
);
#else
NVTE_DEVICE_ERROR
(
"cp_async_bulk_tensor_2d_global_to_shared is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
__device__
__forceinline__
bool
mbarrier_try_wait_parity
(
uint32_t
mbar_ptr
,
const
uint32_t
parity
)
{
__device__
__forceinline__
bool
mbarrier_try_wait_parity
(
uint32_t
mbar_ptr
,
const
uint32_t
parity
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t
waitComplete
;
uint32_t
waitComplete
;
asm
volatile
(
asm
volatile
(
"{
\n\t
.reg .pred P_OUT;
\n\t
"
"{
\n\t
.reg .pred P_OUT;
\n\t
"
...
@@ -98,15 +227,21 @@ __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, cons
...
@@ -98,15 +227,21 @@ __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, cons
:
"r"
(
mbar_ptr
),
"r"
(
parity
)
:
"r"
(
mbar_ptr
),
"r"
(
parity
)
:
"memory"
);
:
"memory"
);
return
static_cast
<
bool
>
(
waitComplete
);
return
static_cast
<
bool
>
(
waitComplete
);
#else
NVTE_DEVICE_ERROR
(
"mbarrier_try_wait_parity is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
return
true
;
}
}
// Block (busy-wait) until the mbarrier's phase parity matches `parity`.
// Repeatedly polls mbarrier_try_wait_parity on the shared-memory address.
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
  }
#else
  NVTE_DEVICE_ERROR("mbarrier_wait_parity is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
// IEEE-754 binary32 layout: 23 explicit mantissa bits, exponent bias 127.
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
...
@@ -129,13 +264,9 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
...
@@ -129,13 +264,9 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
return
__int_as_float
(
biased_exp
<<
FP32_MANTISSA_BITS
);
return
__int_as_float
(
biased_exp
<<
FP32_MANTISSA_BITS
);
}
}
#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \
((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM103_ALL)))
__device__
__forceinline__
e8m0_t
float_to_e8m0
(
float
val
)
{
__device__
__forceinline__
e8m0_t
float_to_e8m0
(
float
val
)
{
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
constexpr
bool
is_blackwell
=
ARCH_BLACKWELL_FAMILY
;
if
constexpr
(
is_blackwell
)
{
uint16_t
out
;
uint16_t
out
;
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
...
@@ -144,7 +275,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
...
@@ -144,7 +275,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
:
"=h"
(
out
)
:
"=h"
(
out
)
:
"f"
(
val
));
:
"f"
(
val
));
return
*
reinterpret_cast
<
e8m0_t
*>
(
&
out
);
return
*
reinterpret_cast
<
e8m0_t
*>
(
&
out
);
#
else
}
else
{
// TODO: nan/inf needs to be set for any value
// TODO: nan/inf needs to be set for any value
// of nan/inf in input not just amax.
// of nan/inf in input not just amax.
if
(
isnan
(
val
))
{
if
(
isnan
(
val
))
{
...
@@ -164,20 +295,22 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
...
@@ -164,20 +295,22 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
++
exponent
;
++
exponent
;
}
}
return
exponent
;
return
exponent
;
#endif
}
}
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
// Issue an async bulk copy of `size` bytes from CTA shared memory to
// global memory.  Completion is tracked via bulk-group commit/wait.
__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(
    uint64_t *dst_global_ptr, const uint64_t *src_shmem, const uint32_t size) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
  asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::
                   "l"(dst_global_ptr),
               "r"(src_shmem_ptr), "r"(size)
               : "memory");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_tensor_1d_shared_to_global is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
...
@@ -185,51 +318,93 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_
...
@@ -185,51 +318,93 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_
// Issue an async bulk copy of a 2D tile from CTA shared memory to global
// memory through a TMA tensor map; (offset_x, offset_y) select the tile.
// Completion is tracked via bulk-group commit/wait.
__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
    const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y,
    uint64_t *src_shmem) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
  asm volatile(
      "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::
          "l"(tensor_map_ptr),
      "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr)
      : "memory");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_shared_to_global is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
// Wait until no committed bulk-async groups remain pending.
__device__ __forceinline__ void cp_async_bulk_wait_group() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("cp.async.bulk.wait_group 0;");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
// Wait until at most W bulk-async groups have pending reads.
// NOTE(review): the primary template issues `wait_group.read 0` regardless
// of W, i.e. it conservatively waits for all groups; the explicit
// specializations below encode the exact immediate for W = 0, 1, 2, 4.
template <size_t W>
__device__ __forceinline__ void cp_async_bulk_wait_group_read() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("cp.async.bulk.wait_group.read 0;");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// Specialization: wait until zero bulk-async groups have pending reads.
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("cp.async.bulk.wait_group.read 0;");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// Specialization: allow up to one bulk-async group with pending reads.
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("cp.async.bulk.wait_group.read 1;");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// Specialization: allow up to two bulk-async groups with pending reads.
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("cp.async.bulk.wait_group.read 2;");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// Specialization: allow up to four bulk-async groups with pending reads.
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("cp.async.bulk.wait_group.read 4;");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
// Commit all previously issued (uncommitted) bulk-async operations into a
// group that can later be waited on.
__device__ __forceinline__ void cp_async_bulk_commit_group() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("cp.async.bulk.commit_group;");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_commit_group is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// Proxy fence (bi-directional):
// orders memory accesses between the generic proxy and the async proxy.
__device__ __forceinline__ void fence_proxy_async() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("fence.proxy.async;");
#else
  NVTE_DEVICE_ERROR("fence_proxy_async is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// Proxy fence restricted to the CTA's shared-memory state space
// (fence.proxy.async.shared::cta).
__device__ __forceinline__ void fence_proxy_async_shared_cta() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("fence.proxy.async.shared::cta;");
#else
  NVTE_DEVICE_ERROR("fence_proxy_async_shared_cta is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -290,15 +465,6 @@ static_assert(sizeof(fp4e2m1x2) == 1);
...
@@ -290,15 +465,6 @@ static_assert(sizeof(fp4e2m1x2) == 1);
static_assert
(
sizeof
(
fp4e2m1x4
)
==
2
);
static_assert
(
sizeof
(
fp4e2m1x4
)
==
2
);
#endif // CUDA_VERSION >= 12080
#endif // CUDA_VERSION >= 12080
// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6.
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on following architectures:
// sm_100a
// sm_101a
// sm_120a
// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
// and the converted values are packed in the destination operand d such that the value
// and the converted values are packed in the destination operand d such that the value
...
@@ -321,6 +487,7 @@ __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons
...
@@ -321,6 +487,7 @@ __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons
// SIMD like "Fused" cast + multiplication (x2)
// SIMD like "Fused" cast + multiplication (x2)
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e4m3x2
&
out
,
const
floatx2
&
in
,
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e4m3x2
&
out
,
const
floatx2
&
in
,
const
floatx2
&
scale
)
{
const
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 val_pair;
\n\t
"
".reg.b64 val_pair;
\n\t
"
...
@@ -333,10 +500,14 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
...
@@ -333,10 +500,14 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in
)),
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_2x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e5m2x2
&
out
,
const
floatx2
&
in
,
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e5m2x2
&
out
,
const
floatx2
&
in
,
const
floatx2
&
scale
)
{
const
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 val_pair;
\n\t
"
".reg.b64 val_pair;
\n\t
"
...
@@ -349,9 +520,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
...
@@ -349,9 +520,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in
)),
:
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
in
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_2x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e4m3x2
&
out
,
const
bf16x2
&
in
,
const
floatx2
&
scale
)
{
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e4m3x2
&
out
,
const
bf16x2
&
in
,
const
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 val_pair_before;
\n\t
"
".reg.b64 val_pair_before;
\n\t
"
...
@@ -371,9 +546,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, con
...
@@ -371,9 +546,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, con
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in
)),
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_2x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e5m2x2
&
out
,
const
bf16x2
&
in
,
const
floatx2
&
scale
)
{
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e5m2x2
&
out
,
const
bf16x2
&
in
,
const
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 val_pair_before;
\n\t
"
".reg.b64 val_pair_before;
\n\t
"
...
@@ -393,9 +572,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, con
...
@@ -393,9 +572,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, con
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in
)),
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_2x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e4m3x2
&
out
,
const
fp16x2
&
in
,
const
floatx2
&
scale
)
{
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e4m3x2
&
out
,
const
fp16x2
&
in
,
const
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 val_pair_before;
\n\t
"
".reg.b64 val_pair_before;
\n\t
"
...
@@ -415,9 +598,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, con
...
@@ -415,9 +598,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, con
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in
)),
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_2x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e5m2x2
&
out
,
const
fp16x2
&
in
,
const
floatx2
&
scale
)
{
__device__
__forceinline__
void
mul_cvt_2x
(
fp8e5m2x2
&
out
,
const
fp16x2
&
in
,
const
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
".reg.b64 val_pair_before;
\n\t
"
".reg.b64 val_pair_before;
\n\t
"
...
@@ -437,24 +624,33 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, con
...
@@ -437,24 +624,33 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, con
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"=h"
(
reinterpret_cast
<
uint16_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in
)),
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in
)),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_2x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
// SIMD xorsign-abs max over a packed pair of bf16 values: per the PTX
// max.xorsign.abs semantics, each lane's magnitude is max(|p1|, |p2|) and
// its sign is the XOR of the input signs.
__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;"
               : "=r"(reinterpret_cast<uint32_t &>(dst))
               : "r"(reinterpret_cast<const uint32_t &>(p1)),
                 "r"(reinterpret_cast<const uint32_t &>(p2)));
#else
  NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
}
// fp16x2 overload of the SIMD xorsign-abs max: per the PTX
// max.xorsign.abs semantics, each lane's magnitude is max(|p1|, |p2|) and
// its sign is the XOR of the input signs.
__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;"
               : "=r"(reinterpret_cast<uint32_t &>(dst))
               : "r"(reinterpret_cast<const uint32_t &>(p1)),
                 "r"(reinterpret_cast<const uint32_t &>(p2)));
#else
  NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
}
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// namespace ptx
}
// namespace ptx
namespace
{
namespace
{
...
@@ -472,6 +668,8 @@ __forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool i
...
@@ -472,6 +668,8 @@ __forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool i
}
}
// Syncthreads so initialized barrier is visible to all threads.
// Syncthreads so initialized barrier is visible to all threads.
__syncthreads
();
__syncthreads
();
#else
NVTE_DEVICE_ERROR
(
"initialize_barriers is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
...
@@ -487,6 +685,8 @@ __forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_m
...
@@ -487,6 +685,8 @@ __forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_m
ptx
::
mbarrier_invalid
(
&
mbar
[
iter
]);
ptx
::
mbarrier_invalid
(
&
mbar
[
iter
]);
}
}
}
}
#else
NVTE_DEVICE_ERROR
(
"destroy_barriers is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
...
@@ -506,6 +706,8 @@ __forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src,
...
@@ -506,6 +706,8 @@ __forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src,
// Other threads just arrive
// Other threads just arrive
ptx
::
mbarrier_arrive
(
barrier
);
ptx
::
mbarrier_arrive
(
barrier
);
}
}
#else
NVTE_DEVICE_ERROR
(
"copy_1d_to_shared is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
...
@@ -525,6 +727,8 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co
...
@@ -525,6 +727,8 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co
// Other threads just arrive
// Other threads just arrive
ptx
::
mbarrier_arrive
(
barrier
);
ptx
::
mbarrier_arrive
(
barrier
);
}
}
#else
NVTE_DEVICE_ERROR
(
"copy_2d_to_shared is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
...
@@ -551,6 +755,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src,
...
@@ -551,6 +755,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src,
// Other threads just arrive
// Other threads just arrive
ptx
::
mbarrier_arrive
(
barrier
);
ptx
::
mbarrier_arrive
(
barrier
);
}
}
#else
NVTE_DEVICE_ERROR
(
"copy_2d_to_sharedx2 is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
...
@@ -580,6 +786,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx3(
...
@@ -580,6 +786,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx3(
// Other threads just arrive
// Other threads just arrive
ptx
::
mbarrier_arrive
(
barrier
);
ptx
::
mbarrier_arrive
(
barrier
);
}
}
#else
NVTE_DEVICE_ERROR
(
"copy_2d_to_sharedx3 is only supported on SM 10.0+."
);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
}
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment