OpenDAS / TransformerEngine / Commits

Commit 063ef88d, authored Dec 03, 2025 by wenjh

    Merge nv main up to v2.10.0.dev0

    Signed-off-by: wenjh <wenjh@sugon.com>

Parents: 91670b05, 5624dbb4
Changes: 298

Showing 20 changed files with 1107 additions and 490 deletions (+1107 / -490)
Changed files shown on this page:

  tests/pytorch/test_parallel_cross_entropy.py                                   +1    -1
  tests/pytorch/test_permutation.py                                              +9    -10
  tests/pytorch/test_recipe.py                                                   +64   -24
  tests/pytorch/test_sanity.py                                                   +66   -35
  tests/pytorch/utils.py                                                         +64   -9
  transformer_engine/common/CMakeLists.txt                                       +30   -0
  transformer_engine/common/activation/activation_template.h                     +4    -6
  transformer_engine/common/activation/gelu.cu                                   +8    -4
  transformer_engine/common/activation/relu.cu                                   +8    -4
  transformer_engine/common/activation/swiglu.cu                                 +21   -2
  transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp              +77   -21
  transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp   +29   -1
  transformer_engine/common/common.cu                                            +10   -2
  transformer_engine/common/common.h                                             +42   -9
  transformer_engine/common/fused_attn/fused_attn.cpp                            +124  -92
  transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu        +278  -197
  transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h         +33   -28
  transformer_engine/common/fused_attn/fused_attn_fp8.cu                         +111  -38
  transformer_engine/common/fused_attn/utils.h                                   +12   -7
  transformer_engine/common/gemm/config.cpp                                      +116  -0
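Most of the test-suite churn in the diffs below follows a single pattern: the fp8_-prefixed entry points are replaced by quantization-agnostic ones exported from transformer_engine.pytorch (fp8_autocast becomes autocast with recipe= instead of fp8_recipe=, fp8_model_init becomes quantized_model_init, and FP8GlobalStateManager.is_*_available() becomes te.is_*_available(return_reason=True)). A minimal sketch of the new-style usage, assuming a CUDA device; the module sizes and the DelayedScaling recipe are chosen purely for illustration:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling

    # quantized_model_init replaces fp8_model_init for allocating quantized weights
    with te.quantized_model_init(enabled=True, recipe=DelayedScaling()):
        linear = te.Linear(32, 32).cuda()

    x = torch.randn(32, 32, device="cuda")

    # te.autocast replaces te.fp8_autocast; the keyword is now `recipe`, not `fp8_recipe`
    with te.autocast(enabled=True, recipe=DelayedScaling()):
        y = linear(x)
    y.sum().backward()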
tests/pytorch/test_parallel_cross_entropy.py

@@ -4,7 +4,7 @@
 import random
 import torch
-from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
+from transformer_engine.pytorch import parallel_cross_entropy
 from utils import dtype_tols
tests/pytorch/test_permutation.py

@@ -8,6 +8,7 @@ import torch
 import pytest
 from typing import Dict, List
 import transformer_engine.pytorch as te
+from transformer_engine.common import recipe
 from transformer_engine.pytorch import (
     moe_permute as te_permute,

@@ -16,14 +17,12 @@ from transformer_engine.pytorch import (
     moe_sort_chunks_by_index as te_sort_chunks_by_index,
     moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs,
 )
-from transformer_engine.pytorch.utils import is_bf16_compatible
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.tensor.float8_tensor import (
+from transformer_engine.pytorch import (
     Float8Quantizer,
     Float8CurrentScalingQuantizer,
+    Float8BlockQuantizer,
+    MXFP8Quantizer,
 )
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 import transformer_engine_torch as tex
 import copy

@@ -1119,7 +1118,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn):
 # TE tensor dtypes
 _te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16]
-if is_bf16_compatible():
+if te.is_bf16_available():
     _te_dtypes.append(tex.DType.kBFloat16)

@@ -1239,10 +1238,10 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
 # Only run FP8 tests on H100.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
+    return_reason=True
+)
 fp8_recipes = [
     recipe.MXFP8BlockScaling(),
tests/pytorch/test_recipe.py

@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
-from typing import Iterable, Optional
+from typing import Optional
 import pytest
 import torch

@@ -11,27 +11,34 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
-from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
+from transformer_engine.pytorch import (
+    Float8BlockQuantizer,
+    MXFP8Quantizer,
+    Float8Quantizer,
+    NVFP4Quantizer,
+    quantized_model_init,
+    Linear,
+    LayerNormLinear,
+    LayerNormMLP,
+    GroupedLinear,
+)
+import transformer_engine_torch as tex
-from transformer_engine.pytorch.fp8 import (
+from transformer_engine.pytorch.quantization import (
     FP8GlobalStateManager,
     _amax_and_scale_update,
-    fp8_model_init,
 )
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
-from transformer_engine.pytorch.distributed import fp8_autocast
 from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
-import transformer_engine_torch as tex

 # Check if FP8 is supported
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
-    FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
+    return_reason=True
+)
+fp4_available, reason_for_no_fp4 = te.is_nvfp4_available(return_reason=True)

 # FP8 per tensor delayed scaling

@@ -64,7 +71,7 @@ class TestFP8Recipe:
             amax_history_len=amax_history_len,
             amax_compute_algo=amax_compute_algo,
         )
-        with te.fp8_autocast(fp8_recipe=recipe):
+        with te.autocast(recipe=recipe):
             module = te.Linear(16, 16)
             y = module(
                 torch.randn([16, 16], device="cuda"),

@@ -120,7 +127,7 @@ class TestFP8Recipe:
         # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)

         # Perform forward, backward, and optimizer steps to update fp8_meta
-        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+        with te.autocast(enabled=True, recipe=recipe):
             x = torch.randn([16, 16], device="cuda")
             y = module(x, is_first_microbatch=is_first_microbatch)
         y.backward(torch.randn_like(y))

@@ -219,7 +226,7 @@ class TestFP8Recipe:
             op.weight.fill_(w_history[-1])

             # Forward and backward pass
-            with te.fp8_autocast(fp8_recipe=recipe):
+            with te.autocast(recipe=recipe):
                 y = op(x)
             y.backward(dy)

@@ -301,7 +308,7 @@ class TestFP8Recipe:
         scaling_factor_compute_algo = None
         if fused_update:
             scaling_factor_compute_algo = (
-                lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute(
+                lambda amax, scale, fp8_max, recipe: te.quantization._default_sf_compute(
                     amax, scale, fp8_max, recipe.margin
                 )
             )

@@ -311,7 +318,7 @@ class TestFP8Recipe:
         # Setup fp8_meta dictionary
         def setup_fp8_meta():
-            with te.fp8_autocast(fp8_recipe=recipe):
+            with te.autocast(recipe=recipe):
                 module = te.Linear(16, 16)
                 y = module(torch.zeros([16, 16], device="cuda"))
             y.backward(torch.zeros_like(y))

@@ -393,11 +400,11 @@ class TestFP8Recipe:
         ],
     )
     def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe):
-        with fp8_model_init(enabled=True, recipe=model_init_recipe):
+        with quantized_model_init(enabled=True, recipe=model_init_recipe):
             linear = Linear(32, 32).cuda()
             x = torch.randn(32, 32, device="cuda")
-        with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
+        with te.autocast(enabled=True, recipe=DelayedScaling()):
             with pytest.raises(RuntimeError) as excinfo:
                 _ = linear(x)
             assert "Recipe mismatch for " in str(excinfo.value)

@@ -436,7 +443,7 @@ class TestFP8Recipe:
         # Run initial iterations with DelayedScaling
         for _ in range(3):
             x = torch.randn(batch_size, in_features, device="cuda")
-            with fp8_autocast(enabled=True, fp8_recipe=initial_recipe):
+            with te.autocast(enabled=True, recipe=initial_recipe):
                 y = linear(x)
             loss = y.mean()
             loss.backward()

@@ -453,7 +460,7 @@ class TestFP8Recipe:
             if i == 0:
                 # Expect a warning on the first iteration with the new recipe
                 with pytest.warns(UserWarning, match="Recipe type changed"):
-                    with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
+                    with te.autocast(enabled=True, recipe=target_recipe):
                         y = linear(x)
                 for quantizer in linear.quantizers["scaling_fwd"]:
                     assert isinstance(quantizer, expected_quantizer_type)

@@ -461,7 +468,7 @@ class TestFP8Recipe:
                 # No warning expected on subsequent iterations
                 with warnings.catch_warnings():
                     warnings.simplefilter("error")  # Raise error if unexpected warning occurs
-                    with fp8_autocast(enabled=True, fp8_recipe=target_recipe):
+                    with te.autocast(enabled=True, recipe=target_recipe):
                         y = linear(x)
             loss = y.mean()
             loss.backward()

@@ -485,7 +492,7 @@ class TestFP8Recipe:
         batch_size = 32
         recipe = DelayedScaling(amax_history_len=1024)
-        with fp8_model_init(recipe=recipe):
+        with quantized_model_init(recipe=recipe):
             if module_class == GroupedLinear:
                 module = module_class(1, in_features, out_features).cuda()
             else:

@@ -493,10 +500,43 @@ class TestFP8Recipe:
         x = torch.randn(batch_size, in_features, device="cuda")
         recipe = DelayedScaling(amax_history_len=1)
-        with fp8_autocast(enabled=True, fp8_recipe=recipe):
+        with te.autocast(enabled=True, recipe=recipe):
             warn_msg = "Quantizer is being updated, this may affect model behavior"
             with pytest.warns(UserWarning, match=warn_msg):
                 if module_class == GroupedLinear:
                     y = module(x, [batch_size])
                 else:
                     y = module(x)

+@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize(
+    "M, N",
+    [
+        # full tile cases
+        (128, 128),
+        (256, 1024),
+        (1024, 256),
+        # Padding required cases
+        (256, 272),
+        (304, 304),
+        (320, 256),
+        # # largest tile
+        (8192, 8192),
+    ],
+)
+def test_fp4_dequantize(dtype, M, N):
+    q = NVFP4Quantizer()
+    a = torch.rand((M, N)).cuda().to(dtype=dtype)
+    starting_tensor = q(a)
+    dequantized_tensor = starting_tensor.dequantize()
+    new_tensor = q(dequantized_tensor)
+    torch.testing.assert_close(
+        new_tensor._rowwise_data,
+        starting_tensor._rowwise_data,
+        rtol=0,
+        atol=0,
+    )
+    new_dequantized_tensor = new_tensor.dequantize()
+    torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor)
tests/pytorch/test_sanity.py

@@ -9,18 +9,16 @@ import pytest
 import os
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
-import transformer_engine.pytorch
-from transformer_engine.pytorch.fp8 import (
-    fp8_autocast,
-    FP8GlobalStateManager,
-    fp8_model_init,
-)
+import transformer_engine
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,
-    is_bf16_compatible,
 )
 from transformer_engine.pytorch import (
+    autocast,
+    quantized_model_init,
     LayerNormLinear,
     Linear,
     GroupedLinear,

@@ -28,26 +26,25 @@ from transformer_engine.pytorch import (
     TransformerLayer,
     RMSNorm,
     LayerNorm,
+    Float8CurrentScalingQuantizer,
+    Float8Quantizer,
+    Float8Tensor,
+    MXFP8Tensor,
+    checkpoint,
+    QuantizedTensor,
+    is_bf16_available,
 )
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions import general_gemm
 from transformer_engine.pytorch.module.base import get_workspace
-from transformer_engine.pytorch.tensor import QuantizedTensor
-from transformer_engine.pytorch.tensor.float8_tensor import (
-    Float8CurrentScalingQuantizer,
-    Float8Quantizer,
-    Float8Tensor,
-)
-from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
-from transformer_engine.pytorch.distributed import checkpoint
 from utils import ModelConfig

 # Only run FP8 tests on supported devices.
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

 # Record initial RNG state from script run.
 seed = 1234

@@ -88,9 +85,19 @@ model_configs = {
     "large": ModelConfig(2, 128, 4, 128, num_layers=1),
 }

+def nvfp4_vanilla():
+    nvfp4_recipe = recipe.NVFP4BlockScaling()
+    nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+    nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+    nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+    return nvfp4_recipe

 fp8_recipes = []
 if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
+    fp8_recipes.append(nvfp4_vanilla())  # TODO: fix check for this
 if fp8_block_scaling_available:
     fp8_recipes.append(recipe.Float8BlockScaling())
 if fp8_available:

@@ -99,7 +106,7 @@ if fp8_available:
     fp8_recipes.append(None)

 param_types = [torch.float32, torch.float16]
-if is_bf16_compatible():  # bf16 requires sm_80 or higher
+if is_bf16_available():  # bf16 requires sm_80 or higher
     param_types.append(torch.bfloat16)

 all_boolean = [True, False]

@@ -151,7 +158,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
     use_fp8 = fp8_recipe is not None
     with torch.autocast(device_type="cuda", enabled=True, dtype=dtype):
-        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        with autocast(enabled=use_fp8, recipe=fp8_recipe):
             te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
         loss = te_out.sum()

@@ -190,7 +197,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
         p.main_grad = torch.zeros_like(p)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
     loss = te_out.sum()
     loss.backward()

@@ -218,7 +225,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
         _disable_wgrads(block)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states)
     loss = te_out.sum()
     loss.backward()

@@ -244,7 +251,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad):
         _disable_wgrads(block)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask)
     loss = te_out.sum()
     loss.backward()

@@ -276,7 +283,7 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad):
         _disable_wgrads(block)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         te_out = block(
             te_inp_hidden_states,
             attention_mask=te_inp_attn_mask,

@@ -305,7 +312,7 @@ def _test_sanity_common(
         _disable_wgrads(block)

     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         if not microbatching:
             te_out = block(te_inp)
         else:

@@ -386,6 +393,8 @@ def test_sanity_layernorm_linear(
             pytest.skip(reason_for_no_mxfp8)
         if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     init_method = init_method_normal(sigma)

@@ -414,6 +423,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
     if fp8_recipe is not None:
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)

@@ -450,9 +461,11 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
             pytest.skip(reason_for_no_fp8_block_scaling)
         if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     use_fp8 = fp8_recipe is not None
-    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
+    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
         te_linear = Linear(
             config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
         ).cuda()

@@ -460,7 +473,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
     inp_hidden_states = torch.randn(
         num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         out = te_linear(inp_hidden_states)
     loss = out.sum()
     loss.backward()

@@ -489,9 +502,11 @@ def test_sanity_grouped_linear(
     if fp8_recipe is not None:
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4():
+            pytest.skip("NVFP4 not supported for grouped linear")

     use_fp8 = fp8_recipe is not None
-    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
+    with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
         te_grouped_linear = GroupedLinear(
             num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
         ).cuda()

@@ -507,7 +522,7 @@ def test_sanity_grouped_linear(
     elif empty_split == "middle":
         m_splits[num_gemms // 2] = 0
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with autocast(enabled=use_fp8, recipe=fp8_recipe):
         out = te_grouped_linear(inp_hidden_states, m_splits)
     loss = out.sum()
     loss.backward()

@@ -545,6 +560,8 @@ def test_sanity_layernorm_mlp(
             pytest.skip(reason_for_no_fp8_block_scaling)
         if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     init_method = init_method_normal(sigma)

@@ -593,6 +610,8 @@ def test_sanity_gpt(
             pytest.skip(reason_for_no_mxfp8)
         if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     init_method = init_method_normal(sigma)

@@ -654,6 +673,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
             pytest.skip(reason_for_no_fp8)
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     init_method = init_method_normal(sigma)

@@ -708,6 +729,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
             pytest.skip(reason_for_no_fp8)
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     init_method = init_method_normal(sigma)

@@ -765,6 +788,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
             pytest.skip(reason_for_no_mxfp8)
         if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     init_method = init_method_normal(sigma)

@@ -801,6 +826,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
             pytest.skip(reason_for_no_mxfp8)
         if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     init_method = init_method_normal(sigma)

@@ -841,6 +868,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
             pytest.skip(reason_for_no_mxfp8)
         if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     init_method = init_method_normal(sigma)

@@ -881,6 +910,8 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
             pytest.skip(reason_for_no_mxfp8)
         if not config.is_fp8_supported():
             pytest.skip("Model config does not support FP8")
+        if fp8_recipe.nvfp4() and dtype == torch.float16:
+            pytest.skip("FP16 output for NVFP4 not supported")

     sigma = 0.023
     init_method = init_method_normal(sigma)

@@ -991,9 +1022,9 @@ def test_replace_raw_data_for_float8tensor():
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-def test_fp8_model_init_high_precision_init_val():
-    """Test fp8_model_init with preserve_high_precision_init_val=True"""
-    with fp8_model_init(preserve_high_precision_init_val=True):
+def test_quantized_model_init_high_precision_init_val():
+    """Test quantized_model_init with preserve_high_precision_init_val=True"""
+    with quantized_model_init(preserve_high_precision_init_val=True):
         model = Linear(768, 768)
     weight = model.weight

@@ -1066,7 +1097,7 @@ def test_linear_frozen_weights_memory_default_recipe():
     linear.weight.requires_grad = False

     # Forward and backward pass with FP8
-    with fp8_autocast():
+    with autocast():
         o = linear(x)
     g_o = torch.randn_like(o)

@@ -1120,7 +1151,7 @@ def test_inference_mode(
     # Construct module
     module = None
     with torch.no_grad():
-        with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe):
+        with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe):
             if module_name == "Linear":
                 module = Linear(hidden_size, hidden_size)
             elif module_name == "LayerNormLinear":

@@ -1155,6 +1186,6 @@ def test_inference_mode(
     kwargs = {}
     if module_name == "GroupedLinear":
         kwargs["m_splits"] = [sequence_length]
-    with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe):
+    with autocast(enabled=with_quantization, recipe=quantization_recipe):
         y = module(x, **kwargs)
     check_weights()
tests/pytorch/utils.py

@@ -7,19 +7,20 @@ from __future__ import annotations
 import logging
 import os
 from contextlib import contextmanager
 from typing import Optional, Tuple, Dict, Any, List
 import pytest
 import torch
 import transformer_engine
 import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch import InferenceParams
 from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends
 from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     get_attention_backend,
     AttentionParams,
     AttentionLogging,
+    check_set_window_size,
 )
 from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend

@@ -72,6 +73,8 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
     # Transformer Engine dtypes
     if isinstance(dtype, tex.DType):
+        if dtype == tex.DType.kFloat4E2M1:
+            return dict(rtol=0.25, atol=0.125)  # epsilon = 0.25
         dtype = {
             tex.DType.kByte: torch.uint8,
             tex.DType.kInt32: torch.int32,

@@ -94,10 +97,25 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
     if dtype == torch.float8_e4m3fn:
         return dict(rtol=0.125, atol=0.0675)  # epsilon = 0.0625
     if dtype == torch.float8_e5m2:
-        return dict(rtol=0.25, atol=0.125)  # epsilon = 0.152
+        return dict(rtol=0.25, atol=0.125)  # epsilon = 0.125
     raise ValueError(f"Unsupported dtype ({dtype})")

+def quantization_tols(name: str) -> dict[str, float]:
+    """Estimated numerical error for a quantization scheme"""
+    if name in (
+        "fp8",
+        "fp8_delayed_scaling",
+        "fp8_current_scaling",
+        "mxfp8",
+        "mxfp8_block_scaling",
+    ):
+        return dtype_tols(tex.DType.kFloat8E4M3)
+    if name == "nvfp4":
+        return dtype_tols(tex.DType.kFloat4E2M1)
+    raise ValueError(f"Unsupported quantization scheme ({name})")

 def make_recipe(name: Optional[str]) -> Optional[Recipe]:
     """Make recipe for quantization scheme"""
     if name is None:

@@ -117,6 +135,12 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
         )
     if name == "fp8_block_scaling":
         return transformer_engine.common.recipe.Float8BlockScaling()
+    if name == "nvfp4":
+        return transformer_engine.common.recipe.NVFP4BlockScaling(
+            disable_rht=True,
+            disable_stochastic_rounding=True,
+            disable_2d_quantization=True,
+        )
     raise ValueError(f"Unsupported quantization scheme ({name})")

@@ -137,6 +161,31 @@ def reset_rng_states() -> None:
     torch.cuda.set_rng_state(cuda_rng_state)

+def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8):
+    if not is_fp8:
+        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
+        return
+    try:
+        if a.dtype != b.dtype:
+            a = a.to(b.dtype)
+        torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
+    except Exception as e:
+        logging.debug(e)
+        rmse = torch.sqrt((a - b).square().mean()).item()
+        logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
+        rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
+        assert rmse < rmse_tol * rmse_range, (
+            name_a
+            + " vs "
+            + name_b
+            + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+                rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
+            )
+        )

 class ModelConfig:
     def __init__(
         self,

@@ -147,12 +196,15 @@ class ModelConfig:
         max_seqlen_kv: int = None,
         num_gqa_groups: int = None,
         head_dim_v: int = None,
+        softmax_type: str = "vanilla",
         dropout_p: float = 0.0,
         attn_mask_type: str = "no_mask",
         attn_bias_type: str = "no_bias",
         alibi_type: str = "none",
         bias_shape: str = "1hss",
         window_size: Tuple[int, int] = (-1, -1),
+        context_parallel: bool = False,
+        cp_comm_type: str = "p2p",
         total_requests: int = None,
         max_ctx_len: int = None,
         num_layers: int = 1,

@@ -171,13 +223,16 @@ class ModelConfig:
         self.kv_channels = (self.head_dim_qk, self.head_dim_v)
         self.hidden_size = self.num_heads * self.head_dim_qk
         self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
+        self.softmax_type = softmax_type
         self.dropout_p = dropout_p
         self.attn_mask_type = attn_mask_type
         self.attn_bias_type = attn_bias_type
         self.alibi_type = alibi_type
         self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
         self.bias_shape = bias_shape
-        self.window_size = window_size
+        self.window_size = check_set_window_size(self.attn_mask_type, window_size)
+        self.context_parallel = context_parallel
+        self.cp_comm_type = cp_comm_type
         self.total_requests = total_requests
         self.max_ctx_len = max_ctx_len
         self.num_layers = num_layers

@@ -198,9 +253,7 @@ def get_available_attention_backends(
     config: ModelConfig,
     qkv_dtype: torch.dtype,
     qkv_layout: str,
-    window_size: Tuple[int, int] = (-1, -1),
     pad_between_seqs: bool = False,
-    context_parallel: bool = False,
     deterministic: bool = False,
     fp8: bool = False,
     fp8_meta: Optional[Dict[str, Any]] = None,

@@ -250,19 +303,21 @@ def get_available_attention_backends(
         head_dim_qk=config.head_dim_qk,
         head_dim_v=config.head_dim_v,
         attn_mask_type=config.attn_mask_type,
-        window_size=window_size,
+        window_size=config.window_size,
         alibi_slopes_shape=alibi_slopes_shape,
         core_attention_bias_type=config.attn_bias_type,
         core_attention_bias_shape=core_attention_bias_shape,
         core_attention_bias_requires_grad=core_attention_bias_requires_grad,
         pad_between_seqs=pad_between_seqs,
         attention_dropout=config.dropout_p,
-        context_parallel=context_parallel,
+        context_parallel=config.context_parallel,
+        cp_comm_type=config.cp_comm_type,
         deterministic=deterministic,
         fp8=fp8,
         fp8_meta=fp8_meta,
         is_training=is_training,
         inference_params=inference_params,
+        softmax_type=config.softmax_type,
     )
     (
         use_flash_attention,
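For orientation, a hypothetical call to the new compare_and_assert helper from a test might look like the sketch below; the tensors, the rmse_tol value, and the quantization_tols key are illustrative and not taken from the diff. With is_fp8=True the helper falls back to a range-normalized RMSE check when the elementwise comparison fails.

    import torch
    from utils import compare_and_assert, quantization_tols

    out_test = torch.randn(128, 128, device="cuda")
    out_ref = out_test + 1e-3 * torch.randn_like(out_test)

    tols = quantization_tols("fp8_delayed_scaling")  # rtol/atol estimated for FP8 (E4M3)
    compare_and_assert(
        out_test, out_ref, "out_test", "out_ref",
        atol=tols["atol"], rtol=tols["rtol"],
        rmse_tol=0.1,  # accept RMSE up to 10% of the observed value range
        is_fp8=True,   # enables the RMSE fallback path
    )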
transformer_engine/common/CMakeLists.txt

@@ -110,6 +110,28 @@ set(CUTLASS_TOOLS_INCLUDE_DIR
 # Python
 find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)

+# NVIDIA MathDX include directory (from Python package install location)
+if(NOT DEFINED MATHDX_INCLUDE_DIR)
+  execute_process(
+    COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
+    OUTPUT_VARIABLE _PIP_SHOW_MATHDX
+    ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
+    RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
+  if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
+    message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
+  endif()
+  string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
+  if(NOT _MATHDX_LOC_MATCH)
+    message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
+  endif()
+  set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
+  set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
+endif()
+if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
+  message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
+endif()

 # Configure Transformer Engine library
 include_directories(${PROJECT_SOURCE_DIR}/..)
 set(transformer_engine_SOURCES)

@@ -132,6 +154,7 @@ if(USE_CUDA)
       transpose/quantize_transpose_square_blockwise.cu
       transpose/quantize_transpose_vector_blockwise.cu
       transpose/swap_first_dims.cu
+      transpose/quantize_transpose_vector_blockwise_fp4.cu
       activation/gelu.cu
       dropout/dropout.cu
       fused_attn/flash_attn.cu

@@ -144,6 +167,7 @@ if(USE_CUDA)
       fused_attn/fused_attn_fp8.cu
       fused_attn/fused_attn.cpp
       fused_attn/utils.cu
+      gemm/config.cpp
       gemm/cublaslt_gemm.cu
       gemm/cutlass_grouped_gemm.cu
       normalization/common.cpp

@@ -162,6 +186,7 @@ if(USE_CUDA)
       util/multi_stream.cpp
       util/rtc.cpp
       swizzle/swizzle.cu
+      swizzle/swizzle_block_scaling.cu
       fused_softmax/scaled_masked_softmax.cu
       fused_softmax/scaled_upper_triang_masked_softmax.cu
       fused_softmax/scaled_aligned_causal_masked_softmax.cu

@@ -172,6 +197,9 @@ if(USE_CUDA)
       recipe/current_scaling.cu
       recipe/delayed_scaling.cu
       recipe/fp8_block_scaling.cu
+      recipe/nvfp4.cu
+      hadamard_transform/hadamard_transform.cu
+      hadamard_transform/hadamard_transform_cast_fusion.cu
       comm_gemm_overlap/userbuffers/ipcsocket.cc
       comm_gemm_overlap/userbuffers/userbuffers-host.cpp
       comm_gemm_overlap/userbuffers/userbuffers.cu

@@ -206,6 +234,7 @@ else()
       dropout/dropout.cu
       activation/relu.cu
       activation/swiglu.cu
+      gemm/config.cpp
       gemm/cublaslt_gemm.cu
       gemm/hipblas_gemm.cu
       normalization/common.cpp

@@ -224,6 +253,7 @@ else()
       util/multi_stream.cpp
       util/rtc.cpp
       swizzle/swizzle.cu
+      swizzle/swizzle_block_scaling.cu
       fused_softmax/scaled_masked_softmax.cu
       fused_softmax/scaled_upper_triang_masked_softmax.cu
       fused_softmax/scaled_aligned_causal_masked_softmax.cu
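The new CMake block derives the MathDX header path from the nvidia-mathdx pip package; the escape hatch called out in its error message is to set MATHDX_INCLUDE_DIR explicitly. A small sketch, purely for illustration, of reproducing the same lookup from Python (it assumes the nvidia-mathdx package is installed for the interpreter being queried):

    import subprocess
    import sys

    # Mirrors the `pip show nvidia-mathdx` parsing added to CMakeLists.txt above.
    out = subprocess.run(
        [sys.executable, "-m", "pip", "show", "nvidia-mathdx"],
        capture_output=True, text=True, check=True,
    ).stdout
    location = next(
        line.split(": ", 1)[1] for line in out.splitlines() if line.startswith("Location: ")
    )
    print(f"{location}/nvidia/mathdx/include")  # candidate value for MATHDX_INCLUDE_DIR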
transformer_engine/common/activation/activation_template.h

@@ -51,22 +51,20 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
 }

 template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
-void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
+void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
   using namespace detail;
   constexpr bool IS_DGATED = false;
   constexpr NVTETensor grad = nullptr;
-  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, stream);
+  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
 }

 template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
           ComputeType (*DActOP)(ComputeType, const Param &)>
-void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
-                   cudaStream_t stream) {
+void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
+                   cudaStream_t stream) {
   using namespace detail;
   constexpr bool IS_DGATED = true;
-  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, stream);
+  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
 }

 }  // namespace transformer_engine
transformer_engine/common/activation/gelu.cu

@@ -23,14 +23,16 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
 void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_geglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dgeglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, e, stream);
 }

 void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {

@@ -49,12 +51,14 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
 void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_qgeglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dqgeglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, e, stream);
 }
transformer_engine/common/activation/relu.cu

@@ -23,14 +23,16 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
 void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_reglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dreglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, e, stream);
 }

 void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {

@@ -49,12 +51,14 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
 void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_sreglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dsreglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, e, stream);
 }
transformer_engine/common/activation/swiglu.cu

@@ -23,12 +23,31 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
 void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_swiglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dswiglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, e, stream);
+}
+
+void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
+                         cudaStream_t stream) {
+  NVTE_API_CALL(nvte_clamped_swiglu);
+  using namespace transformer_engine;
+  ClampedSwiGLUParam param = {limit, alpha};
+  gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
+}
+
+void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+                          float limit, float alpha, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_clamped_dswiglu);
+  using namespace transformer_engine;
+  ClampedSwiGLUParam param = {limit, alpha};
+  dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
+      grad, input, output, param, stream);
 }
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

@@ -79,6 +79,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
 #endif
     _comm_created = true;
   }

+  initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority,
+             num_comm_sm, set_sm_margin, use_ce, atomic_gemm);
+}
+
+void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams,
+                                 int comm_cga_size, int gemm_priority, int comm_priority,
+                                 int num_comm_sm, bool set_sm_margin, bool use_ce,
+                                 bool atomic_gemm) {
   _use_ce = static_cast<int>(use_ce);
   _num_comm_sm = num_comm_sm;
   _cga_size = comm_cga_size;

@@ -339,6 +348,11 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
     _ub_force_blas_multistream = true;
   }
   _ub_stream_nums = num_max_streams;

+  initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm);
+}
+
+void CommOverlapBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
+                                 bool rs_overlap_first_gemm) {
   _rs_overlap_first_gemm = rs_overlap_first_gemm;
   _rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
   NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,

@@ -349,7 +363,9 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
   size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
   void *buffer_ptr;
   _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
-  if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
+  if (_ub_comm->myrank == 0) {
+    printf("!!! [UB] Register UBuf %d\n", _ub_reg);
+  }
   _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);

   int comm_cu_nums = getIntEnv("NVTE_UB_COMM_CU_NUMS", 8, 4);

@@ -765,6 +781,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
     _ub_force_blas_multistream = true;
   }
   _ub_stream_nums = num_max_streams;

+  initialize(buffer_shape, buffer_dtype, comm_type, aggregate);
+}
+
+void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
+                                    CommOverlapType comm_type, bool aggregate) {
   _is_p2p = true;
   _is_reduce_scatter = comm_type == CommOverlapType::RS;
   _aggregate = aggregate;

@@ -772,28 +793,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
   // Create workspace tensor with userbuffer
   NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
   size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
-  int buffer_chunk_bytes = buffer_bytes / tp_size;
-  _num_ubuf_chunks = tp_size;
+  int buffer_chunk_bytes = buffer_bytes / _tp_size;
+  _num_ubuf_chunks = _tp_size;
   if (_is_reduce_scatter) {
     // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk
     // outputs for reduction at the end of the pipelining.
-    buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1);
-    _num_ubuf_chunks = tp_size * 2 - 1;
+    buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1);
+    _num_ubuf_chunks = _tp_size * 2 - 1;
   }

   void *buffer_ptr;
   _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
-  if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
+  if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg);
   _ubuf = TensorWrapper(
       buffer_ptr,
-      std::vector<size_t>{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
+      std::vector<size_t>{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]},
      buffer_dtype);

   // Create tensor chunks for easy management
   char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
   for (int i = 0; i < _num_ubuf_chunks; i++) {
     _ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
-                                   std::vector<size_t>{buffer_shape[0] / tp_size, buffer_shape[1]},
+                                   std::vector<size_t>{buffer_shape[0] / _tp_size, buffer_shape[1]},
                                    buffer_dtype));
     ubuf_byte_ptr += buffer_chunk_bytes;
   }

@@ -818,7 +839,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
   static cudaStream_t send_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
   static cudaStream_t recv_stream;
-  for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
+  for (int i = 0; i < std::min(_ub_stream_nums, _tp_size); i++) {
     if (send_streams[i] == nullptr) {
       NVTE_CHECK_CUDA(
           cudaStreamCreateWithPriority(&send_streams[i], cudaStreamNonBlocking, _comm_priority));
     }

@@ -842,6 +863,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
   }
 }

+void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source,
+                                          bool local_chunk, bool rowwise) {
+  // Check element size
+  const size_t element_size = source.element_size();
+  NVTE_CHECK(_ubuf.element_size() == element_size,
+             "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+             "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+             " bytes)");
+
+  // Input data
+  const size_t source_size = source.numel();
+  const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr();
+
+  // Userbuffers data
+  void *dst_ptr;
+  if (local_chunk) {
+    NVTE_CHECK(_ubufs[_tp_id].numel() == source_size,
+               "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+               "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+    dst_ptr = _ubufs[_tp_id].dptr();
+  } else {
+    NVTE_CHECK(_ubuf.numel() == source_size,
+               "Tried to copy an invalid tensor into a Userbuffers buffer ",
+               "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")");
+    dst_ptr = _ubuf.dptr();
+  }
+
+  // Copy data
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size,
+                                  cudaMemcpyDeviceToDevice, stream));
+}

 TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
                                                          size_t chunk_id) {
   // Start with a chunk of the source tensor

@@ -982,6 +1035,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   const bool do_gelu = pre_gelu_out.numel() > 0;
   size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

+  // Check B copy sizing
+  if (B_copy.numel() > 0) {
+    NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ",
+               _ubuf.numel(), " elements but got ", B_copy.numel());
+    NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(),
+               "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8,
+               "-bit data type but got ", B_copy.element_size() * 8, "-bit");
+  }

   NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));

@@ -1057,12 +1119,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
         NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
         NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()],
                                             _stop_recv, 0));
-      } else if (B_copy.numel() > 0) {
-        assert(B_copy.numel() == _ubufs[_tp_id].numel());
-        assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
-        NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
-                                        _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
-                                        _stream_send[0]));
       }
     }
   } else {

@@ -1117,16 +1173,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
         NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
         NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()],
                                             _stop_recv, 0));
-      } else if (B_copy.numel() > 0) {
-        assert(B_copy.numel() == _ubufs[_tp_id].numel());
-        assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
-        NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
-                                        _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
-                                        _stream_send[0]));
       }
     }
   }

+  // Copy all-gathered B from communication buffer into auxiliary output
+  if (B_copy.numel() > 0) {
+    NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(),
+                                    cudaMemcpyDeviceToDevice, _stream_send[0]));
+  }

   _ub_comm->sms = ori_sms;
   for (size_t i = 0; i < _stream_compute.size(); i++) {
     NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp

@@ -679,9 +679,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
                 reinterpret_cast<void *>(&memhndl), sizeof(cudaIpcMemHandle_t), comm->comm_intra);

+  // Check for NVLINK support before attempting IPC operations
+  if (comm->nvsize > 1) {
+    int current_device;
+    NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
+    cudaDeviceProp deviceProp;
+    NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device));
+    bool peer_access_available = false;
+    for (int i = 0; i < comm->nvsize; i++) {
+      if (i != comm->nvrank) {
+        int can_access_peer;
+        cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i);
+        if (peer_result == cudaSuccess && can_access_peer) {
+          peer_access_available = true;
+          break;
+        }
+      }
+    }
+    if (!peer_access_available) {
+      free(tmp);
+      NVTE_ERROR(
+          "No peer-to-peer access available between GPUs. This platform does not support the "
+          "GPU-to-GPU "
+          "communication required for multi-GPU userbuffers. Consider using single-GPU mode.");
+      return 1;
+    }
+  }
+
   for (int i = 0; i < comm->nvsize; i++) {
     if (i != comm->nvrank) {
-      NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i],  // NOLINT(*)
+      NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i],
                                            cudaIpcMemLazyEnablePeerAccess));
     }
   }

@@ -702,4 +729,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
   comm->mem_ptr[hndl] = *gpubuff;

   return comm->free_region++;
+  printf("***** Returning *****\n");
 }
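As an aside, the same precondition that the new guard probes (at least one peer GPU reachable via P2P) can also be checked from Python before launching a multi-GPU userbuffers job. This is an illustrative sketch, not part of the change:

    import torch

    def any_peer_accessible(device: int = 0) -> bool:
        # Host-side analogue of the cudaDeviceCanAccessPeer probe added above.
        return any(
            torch.cuda.can_device_access_peer(device, peer)
            for peer in range(torch.cuda.device_count())
            if peer != device
        )

    print(any_peer_accessible(0))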
transformer_engine/common/common.cu

@@ -39,6 +39,10 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
       return CUDA_R_8F_E4M3;
     case DType::kFloat8E5M2:
       return CUDA_R_8F_E5M2;
+#if CUDA_VERSION >= 12080
+    case DType::kFloat4E2M1:
+      return CUDA_R_4F_E2M1;
+#endif
     default:
       NVTE_ERROR("Invalid type");
   }

@@ -165,7 +169,9 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
 void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                           const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                           const uint32_t shmemX, const uint32_t stride_elems,
-                          const uint32_t offset_elems, const size_t type_num_bits) {
+                          const uint32_t offset_elems, const size_t type_num_bits,
+                          const CUtensorMapSwizzle swizzle) {
   cuda_driver::ensure_context_exists();

   // Get a function pointer to the cuTensorMapEncodeTiled driver API
   // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
   static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {

@@ -174,6 +180,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
   }();

   // rank is the number of dimensions of the array
   constexpr uint32_t rank = 2;
+  // Dimension for the packed data types must reflect the number of individual U# values.
   uint64_t size[rank] = {globalX, globalY};

   // The stride is the number of bytes to traverse from the first element of one row to the next

@@ -212,7 +220,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
       CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
       // Swizzling can be used to avoid shared memory bank conflicts.
-      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
+      swizzle,
       // L2 Promotion can be used to widen the effect of a cache-policy to a wider
       // set of L2 cache lines.
transformer_engine/common/common.h
View file @
063ef88d
...
...
@@ -54,8 +54,14 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
return
mode
==
NVTE_DELAYED_TENSOR_SCALING
;
}
inline
bool
is_nvfp4_scaling
(
const
NVTEScalingMode
&
mode
)
{
return
mode
==
NVTE_NVFP4_1D_SCALING
;
}
inline
bool
is_mxfp8_scaling
(
const
NVTEScalingMode
&
mode
)
{
return
mode
==
NVTE_MXFP8_1D_SCALING
;
}
inline
bool
is_mxfp_scaling
(
const
NVTEScalingMode
&
mode
)
{
return
mode
==
NVTE_MXFP8_1D_SCALING
;
}
inline
bool
is_nvfp_scaling
(
const
NVTEScalingMode
&
mode
)
{
return
mode
==
NVTE_NVFP4_1D_SCALING
;
}
inline
size_t
product
(
const
std
::
vector
<
size_t
>
&
shape
,
const
size_t
begin
,
const
size_t
end
)
{
NVTE_CHECK
(
begin
<=
end
&&
end
<=
shape
.
size
(),
"Attempted to access entries "
,
begin
,
" to "
,
end
,
" in a vector with "
,
shape
.
size
(),
" entries"
);
...
...
@@ -114,6 +120,7 @@ struct Tensor {
SimpleTensor
data
;
SimpleTensor
columnwise_data
;
SimpleTensor
amax
;
SimpleTensor
columnwise_amax
;
SimpleTensor
scale
;
SimpleTensor
scale_inv
;
SimpleTensor
columnwise_scale_inv
;
...
...
@@ -125,6 +132,7 @@ struct Tensor {
:
data
(),
columnwise_data
(),
amax
(
nullptr
,
{
1
},
DType
::
kFloat32
),
columnwise_amax
(
nullptr
,
{
1
},
DType
::
kFloat32
),
scale
(
nullptr
,
{
1
},
DType
::
kFloat32
),
scale_inv
(
nullptr
,
{
1
},
DType
::
kFloat32
),
columnwise_scale_inv
(
nullptr
,
{
1
},
DType
::
kFloat32
),
...
...
@@ -135,6 +143,7 @@ struct Tensor {
data
.
clear
();
columnwise_data
.
clear
();
amax
.
clear
();
columnwise_amax
.
clear
();
scale
.
clear
();
scale_inv
.
clear
();
columnwise_scale_inv
.
clear
();
...
...
@@ -180,6 +189,7 @@ struct Tensor {
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
*/
switch
(
scaling_mode
)
{
case
NVTE_NVFP4_1D_SCALING
:
case
NVTE_DELAYED_TENSOR_SCALING
:
if
(
!
has_data
()
&&
has_columnwise_data
())
{
std
::
vector
<
size_t
>
ret
;
...
...
@@ -195,7 +205,6 @@ struct Tensor {
}
break
;
case
NVTE_MXFP8_1D_SCALING
:
case
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING
:
if
(
!
has_data
()
&&
has_columnwise_data
())
{
return
columnwise_data
.
shape
;
}
else
{
...
...
@@ -267,12 +276,18 @@ struct QuantizationConfig {
NVTETensor
noop_tensor
=
nullptr
;
Float8BlockScaleTensorFormat
float8_block_scale_tensor_format
=
Float8BlockScaleTensorFormat
::
GEMM_READY
;
NVTETensor
rng_state
=
nullptr
;
bool
nvfp4_2d_quantization
=
false
;
bool
stochastic_rounding
=
false
;
static
constexpr
size_t
attr_sizes
[]
=
{
sizeof
(
bool
),
// force_pow_2_scales
sizeof
(
float
),
// amax_epsilon
sizeof
(
NVTETensor
),
// noop_tensor
sizeof
(
Float8BlockScaleTensorFormat
)
// float8_block_scale_tensor_format
sizeof
(
Float8BlockScaleTensorFormat
),
// float8_block_scale_tensor_format
sizeof
(
NVTETensor
),
// rng_seed and offset
sizeof
(
bool
),
// nvfp4_2d_quantization
sizeof
(
bool
)
// stochastic_rounding
};
};
...
...
@@ -305,6 +320,8 @@ using fp8e8m0 = __nv_fp8_e8m0;
#endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
+using fp4e2m1x2 = __nv_fp4x2_e2m1;
+using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif
using e8m0_t = uint8_t;
...
...
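Worth spelling out (our note, with sizes we believe but have not verified against this tree): the packed aliases exist because CUDA's FP4 types store sub-byte values in packed form.

    static_assert(sizeof(fp4e2m1x2) == 1, "two E2M1 values in one byte");    // assumed
    static_assert(sizeof(fp4e2m1x4) == 2, "four E2M1 values in two bytes");  // assumed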
@@ -342,12 +359,14 @@ struct TypeExtrema;
template <>
struct TypeExtrema<fp4e2m1> {
  static constexpr float max = 6.0f;
  static constexpr float max_inverse = 1.0 / max;
};
#endif

template <>
struct TypeExtrema<fp8e4m3> {
  static constexpr float max = 448.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
...
...
@@ -358,6 +377,7 @@ struct TypeExtrema<int8> {
template <>
struct TypeExtrema<fp8e5m2> {
  static constexpr float max = 57344.0f;
  static constexpr float max_inverse = 1.0 / max;
};

template <>
...
...
@@ -602,6 +622,18 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
// Add a pack_size argument to select the packed type for FP4
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat4E2M1: { \
using type = __nv_fp4x2_storage_t; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
...
...
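A hedged usage sketch of the new FP4x2 switch (dispatch_packed_fp4 and its arguments are hypothetical): the macro binds `type` to the packed storage type so the body can be written once against byte-packed FP4 data.

    TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
        dtype, pack_size, PackedType,
        dispatch_packed_fp4<PackedType>(input_ptr, output_ptr, numel / 2);)  // hypothetical helper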
@@ -812,10 +844,11 @@ void checkCuDriverContext(CUstream stream);
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);

// Set up parameters to create TMA descriptor.
-void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
-                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
-                          const uint32_t shmemX, const uint32_t stride_elems,
-                          const uint32_t offset_elems, const size_t type_num_bits);
+void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
+                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
+                          const uint32_t shmemX, const uint32_t stride_elems,
+                          const uint32_t offset_elems, const size_t type_num_bits,
+                          const CUtensorMapSwizzle swizzle =
+                              CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
#endif

bool is_supported_by_CC_100();
...
...
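To make the new default argument concrete, a sketch of a call that opts into swizzling; every extent and the 128B choice are illustrative, not taken from this commit:

    CUtensorMap tensor_map;
    create_2D_tensor_map(tensor_map, tensor.data,
                         /*globalY=*/rows, /*globalX=*/cols,    // global extents
                         /*shmemY=*/64, /*shmemX=*/64,          // shared-memory tile
                         /*stride_elems=*/cols, /*offset_elems=*/0,
                         /*type_num_bits=*/8,
                         CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B);

Existing callers compile unchanged because the swizzle parameter defaults to CU_TENSOR_MAP_SWIZZLE_NONE.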
transformer_engine/common/fused_attn/fused_attn.cpp
View file @ 063ef88d
...
...
@@ -135,9 +135,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
-   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout,
-   size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv,
-   size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-   int64_t window_size_right) {
+   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+   float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+   size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
+   int64_t window_size_right) {
  using namespace transformer_engine;
  NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
  const int device_id = cuda::current_device();
...
...
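A minimal caller sketch for the widened signature; the dtype, layout, and size literals here are illustrative only:

    NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
        /*is_training=*/true, kNVTEBFloat16, kNVTEBFloat16, NVTE_QKV_Layout::NVTE_BS3HD,
        NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_CAUSAL_MASK,
        NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, /*dropout=*/0.0f,
        /*num_attn_heads=*/16, /*num_gqa_groups=*/16, /*max_seqlen_q=*/2048,
        /*max_seqlen_kv=*/2048, /*head_dim_qk=*/128, /*head_dim_v=*/128,
        /*window_size_left=*/-1, /*window_size_right=*/-1);

Callers that pass anything other than NVTE_VANILLA_SOFTMAX only get the arbitrary-seqlen backend on cuDNN 9.13.1+, per the gating later in this file.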
@@ -175,7 +176,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       // TODO (cyang): add is_training to nvte_get_fused_attn_backend
       // sm90: fwd d<=256, bwd d=128 only
       // sm100: fwd d<=128, bwd d<=128
-      ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) ||
+      ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
+       (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
        (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
       head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
       (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
...
...
@@ -183,7 +185,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
        attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
      (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
-     !requires_64bit_ragged_offset &&
+     !requires_64bit_ragged_offset &&
+     (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) &&
      // 9.10.0: known bugs with SDPA FP8
      (cudnn_runtime_version != 91000)) {
    if (cudnn_runtime_version >= 8900) {
...
...
@@ -213,7 +215,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
         (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
        ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
-       !requires_64bit_ragged_offset) {
+       !requires_64bit_ragged_offset &&
+       (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) {
      flag_m512 = true;
    }
    if (
...
...
@@ -363,7 +366,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        // check 64-bit ragged offset support
        (supported_ragged_offset_size) &&
        // 9.10.0/9.10.1: known bugs with SDPA F16
-       (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
+       (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) &&
+       // softmax type
+       // pre-9.13.1: vanilla
+       // 9.13.1+: vanilla, off-by-one, learnable
+       (cudnn_runtime_version >= 91301 ||
+        (cudnn_runtime_version < 91301 &&
+         softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) {
      flag_arb = true;
    }
    if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
...
...
@@ -405,14 +414,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}

// NVTE fused attention FWD with packed QKV
-void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
-                                   NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
-                                   const NVTETensor rng_state, size_t max_seqlen, bool is_training,
-                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
+                                   const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
+                                   NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
+                                   const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
+                                   size_t max_seqlen, bool is_training, float attn_scale,
+                                   float dropout, NVTE_QKV_Layout qkv_layout,
                                    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                   int64_t window_size_left, int64_t window_size_right,
-                                   NVTETensor workspace, cudaStream_t stream) {
+                                   NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+                                   int64_t window_size_right, NVTETensor workspace,
+                                   cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
  using namespace transformer_engine;
...
...
@@ -421,6 +432,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
  const Tensor *input_QKV = convertNVTETensorCheck(QKV);
  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+ const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
  Tensor *input_output_S = convertNVTETensorCheck(S);
  Tensor *output_O = convertNVTETensorCheck(O);
  Tensor *wkspace = convertNVTETensor(workspace);
...
...
@@ -447,8 +459,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-     is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
-     max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
+     is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type,
+     dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -463,9 +475,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
#if (CUDNN_VERSION >= 8900)
    fused_attn_arbitrary_seqlen_fwd_qkvpacked(
        b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type,
-       attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O,
-       Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
-       stream, handle);
+       attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias,
+       input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded,
+       input_rng_state, wkspace, stream, handle);
#else
    NVTE_ERROR("cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.\n");
...
...
@@ -487,10 +499,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
                                   const NVTETensor S, NVTETensor dP,
                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
-                                  NVTETensor dBias, const NVTETensor cu_seqlens,
-                                  const NVTETensor cu_seqlens_padded, size_t max_seqlen,
-                                  float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                                  NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+                                  NVTETensor dBias, NVTETensor dSoftmaxOffset,
+                                  const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
+                                  size_t max_seqlen, float attn_scale, float dropout,
+                                  NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                  NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                                   int64_t window_size_left, int64_t window_size_right,
                                   bool deterministic, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
...
...
@@ -505,6 +518,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  Tensor *input_output_dP = convertNVTETensorCheck(dP);
  Tensor *output_dQKV = convertNVTETensorCheck(dQKV);
  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+ Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
  Tensor *wkspace = convertNVTETensor(workspace);

  auto ndim = input_QKV->data.shape.size();
...
...
@@ -529,8 +543,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-     true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
-     max_seqlen, d, d, window_size_left, window_size_right);
+     true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
+     max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -543,19 +557,22 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
#endif
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
-   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-   Tensor *input_Bias, *input_rng_state;
-   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-     input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-     input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-   } else {
-     input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-   }
+   size_t i = 0;
+   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   Tensor *input_Bias, *input_SoftmaxOffset;
+   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+     input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   }
+   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+     input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   }
    fused_attn_arbitrary_seqlen_bwd_qkvpacked(
        b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
-       window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO,
-       input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens,
-       input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
+       softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O,
+       input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias,
+       output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state,
+       wkspace, stream, handle);
#else
    const char *err_msg =
        "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
...
...
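The index-based unpacking above encodes a fixed packing order for the auxiliary tensors. A hedged restatement of the convention (pack, has_bias, and has_softmax_offset are illustrative names):

    size_t i = 0;
    Tensor *stats = convertNVTETensorCheck(pack->tensors[i++]);  // always first
    Tensor *rng = convertNVTETensorCheck(pack->tensors[i++]);    // always second
    Tensor *bias = has_bias ? convertNVTETensorCheck(pack->tensors[i++]) : nullptr;
    Tensor *offset = has_softmax_offset ? convertNVTETensorCheck(pack->tensors[i++]) : nullptr;

Optional entries are appended only when their feature is enabled, which lets one pack layout serve every bias/softmax combination.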
@@ -580,14 +597,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
}

// NVTE fused attention FWD with packed KV
void nvte_fused_attn_fwd_kvpacked(
-   const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
-   NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
-   const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
-   const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
-   const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
-   size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
+   const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
+   const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
+   NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
+   const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+   const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
+   const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
+   size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-   int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
-   cudaStream_t stream) {
+   NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+   NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
  using namespace transformer_engine;
  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...
...
@@ -600,6 +618,7 @@ void nvte_fused_attn_fwd_kvpacked(
  const Tensor *input_Q = convertNVTETensorCheck(Q);
  const Tensor *input_KV = convertNVTETensorCheck(KV);
  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+ const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
  Tensor *input_output_S = convertNVTETensorCheck(S);
  Tensor *output_O = convertNVTETensorCheck(O);
  Tensor *wkspace = convertNVTETensor(workspace);
...
...
@@ -660,8 +679,8 @@ void nvte_fused_attn_fwd_kvpacked(
  const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-     is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
-     max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
+     is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
+     h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -677,10 +696,11 @@ void nvte_fused_attn_fwd_kvpacked(
    fused_attn_arbitrary_seqlen_fwd_kvpacked(
        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
        attn_scale,
-       dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
-       input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
-       input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
-       input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
+       dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
+       window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O,
+       Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+       input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
+       wkspace, stream, handle);
#else
    NVTE_ERROR("cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length.\n");
...
...
@@ -702,12 +722,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
    const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
-   NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
-   const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
-   const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
-   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-   NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right,
-   bool deterministic, NVTETensor workspace, cudaStream_t stream) {
+   NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
+   const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+   const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
+   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+   int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
  using namespace transformer_engine;
  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...
...
@@ -723,6 +743,7 @@ void nvte_fused_attn_bwd_kvpacked(
  Tensor *output_dQ = convertNVTETensorCheck(dQ);
  Tensor *output_dKV = convertNVTETensorCheck(dKV);
  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+ Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
  Tensor *wkspace = convertNVTETensor(workspace);

  size_t b = input_cu_seqlens_q->data.shape[0] - 1;
...
...
@@ -755,8 +776,8 @@ void nvte_fused_attn_bwd_kvpacked(
  const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-     true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
-     max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
+     true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+     h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -770,20 +791,23 @@ void nvte_fused_attn_bwd_kvpacked(
#endif
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
-   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-   Tensor *input_Bias, *input_rng_state;
-   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-     input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-     input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-   } else {
-     input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-   }
+   size_t i = 0;
+   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   Tensor *input_Bias, *input_SoftmaxOffset;
+   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+     input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   }
+   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+     input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   }
    fused_attn_arbitrary_seqlen_bwd_kvpacked(
        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
-       bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
-       input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
-       input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
-       input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
+       bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
+       deterministic, input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset,
+       output_S, output_dQ, output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q,
+       input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
+       input_rng_state, wkspace, stream, handle);
#else
    const char *err_msg =
        "cuDNN 8.9.3 is required for BF16/FP16 fused attention "
...
...
@@ -809,16 +833,17 @@ void nvte_fused_attn_bwd_kvpacked(
}

// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
-                        const NVTETensor Bias, NVTETensor S, NVTETensor O,
-                        NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
-                        const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+                        const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+                        NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+                        const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+                        const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                         const NVTETensor page_table_v, const NVTETensor rng_state,
                         size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                         float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                         NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                        int64_t window_size_left, int64_t window_size_right,
-                        NVTETensor workspace, cudaStream_t stream) {
+                        NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+                        int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_fwd);
  using namespace transformer_engine;
  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...
...
@@ -832,6 +857,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  const Tensor *input_K = convertNVTETensorCheck(K);
  const Tensor *input_V = convertNVTETensorCheck(V);
  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+ const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
  Tensor *input_output_S = convertNVTETensorCheck(S);
  Tensor *output_O = convertNVTETensorCheck(O);
  Tensor *wkspace = convertNVTETensor(workspace);
...
...
@@ -886,8 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-     is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
-     max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+     is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
+     h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -903,10 +929,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
    fused_attn_arbitrary_seqlen_fwd(
        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
        attn_scale,
-       dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
-       input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
-       input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
-       input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
+       dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
+       window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O,
+       Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+       input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
+       wkspace, stream, handle);
#else
    NVTE_ERROR("cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length.\n");
...
...
@@ -928,14 +955,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor O, const NVTETensor dO, const NVTETensor S,
                         NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
                         NVTETensor dK,
-                        NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
-                        const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+                        NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
+                        const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+                        const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
                         size_t max_seqlen_kv, float attn_scale, float dropout,
                         NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                        NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
-                        int64_t window_size_right, bool deterministic, NVTETensor workspace,
-                        cudaStream_t stream) {
+                        NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                        int64_t window_size_left, int64_t window_size_right, bool deterministic,
+                        NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_flash_attn_bwd);
  using namespace transformer_engine;
  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
...
...
@@ -953,6 +981,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  Tensor *output_dK = convertNVTETensorCheck(dK);
  Tensor *output_dV = convertNVTETensorCheck(dV);
  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+ Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
  Tensor *wkspace = convertNVTETensor(workspace);

  auto ndim = input_Q->data.shape.size();
...
...
@@ -978,8 +1007,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-     true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
-     max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+     true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+     h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
...
...
@@ -993,19 +1022,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#endif
  } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
-   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-   Tensor *input_Bias, *input_rng_state;
-   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-     input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-     input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-   } else {
-     input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-   }
+   size_t i = 0;
+   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   Tensor *input_Bias, *input_SoftmaxOffset;
+   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+     input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   }
+   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+     input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   }
    fused_attn_arbitrary_seqlen_bwd(
        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
-       qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic,
-       input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK,
-       output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv,
-       input_cu_seqlens_q_padded,
+       qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
+       deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
+       input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
+       output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
        input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
    const char *err_msg =
...
...
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
View file @ 063ef88d
...
...
@@ -54,10 +54,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
    int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
    int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
    float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-   NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-   int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
-   void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
-   void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK,
-   void *devPtrPageTableV,
+   NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+   int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
+   void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats,
+   void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
+   void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
...
...
@@ -75,6 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
    is_causal = true;
    is_bottom_right = false;
  }
+ bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
  bool is_dropout = (is_training && dropout_probability != 0.0f);
  NVTE_QKV_Format q_format = nvte_get_q_format(layout);
  NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
...
...
@@ -98,8 +100,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
    s_q = is_ragged_q ? max_t_q : s_q;
    s_kv = is_ragged_kv ? max_t_kv : s_kv;
  }

  const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;

  try {
    FADescriptor_v1 descriptor{b, h,
...
...
@@ -122,11 +124,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                               layout,
                               bias_type,
                               mask_type,
+                              softmax_type,
                               window_size_left,
                               window_size_right,
                               true,
-                              tensorType, tensorType};
+                              tensorType, tensorType,
+                              cudnn_frontend::DataType_t::NOT_SET,
+                              cudnn_frontend::DataType_t::NOT_SET,
+                              cudnn_frontend::DataType_t::NOT_SET};
    namespace fe = cudnn_frontend;
    using graph_and_tensors =
...
...
@@ -138,6 +143,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // O
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // Stats
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // bias
+                          std::shared_ptr<fe::graph::Tensor_attributes>,  // softmax_offset
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_q
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_kv
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // page_table_k
...
...
@@ -168,7 +174,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                      .set_intermediate_data_type(fe::DataType_t::FLOAT)
                      .set_compute_data_type(fe::DataType_t::FLOAT);
-     std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
+     std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale, softmax_offset;
      std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
      std::shared_ptr<fe::graph::Tensor_attributes> page_table_k, page_table_v;
      std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
...
...
@@ -302,6 +308,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
        sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
      }

+     if (is_softmax_offset) {
+       softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                              .set_name("softmax_offset")
+                                              .set_dim({1, h, 1, 1})
+                                              .set_stride({h, 1, 1, 1})
+                                              .set_data_type(fe::DataType_t::FLOAT));
+       sdpa_options.set_sink_token(softmax_offset);
+     }
+
      auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options);

      std::vector<int64_t> o_stride(4);
...
...
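For context on what set_sink_token changes (our derivation, not text from the diff): with a learnable per-head offset $t_h$, the row-wise attention weights become

$$p_{ij} = \frac{\exp(s_{ij})}{\exp(t_h) + \sum_{k} \exp(s_{ik})},$$

so the weights sum to less than one and a head can park probability mass on the virtual sink token; $t_h \to -\infty$ recovers vanilla softmax, and a fixed $t_h = 0$ gives the off-by-one variant.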
@@ -338,6 +353,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
      key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
      auto Stats_tuple = std::make_tuple(Stats);
      auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
+     auto softmax_offset_tuple =
+         is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
      auto padding_tuple =
          is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
      auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v)
...
...
@@ -358,17 +375,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
      NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
      NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));

-     auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple,
-                                        Stats_tuple, bias_tuple, padding_tuple, page_table_tuple,
-                                        offset_qo_tuple, offset_kv_tuple, offset_s_tuple,
-                                        dropout_tuple);
+     auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple,
+                                        Stats_tuple, bias_tuple, softmax_offset_tuple,
+                                        padding_tuple, page_table_tuple, offset_qo_tuple,
+                                        offset_kv_tuple, offset_s_tuple, dropout_tuple);
      cache.insert({descriptor, return_tuple});
      return return_tuple;
    };

-   auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k,
-         page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed,
-         dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor);
+   auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv,
+         page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats,
+         dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
// n.b. Care should be taken to align each of the added workspace tensors to their type.
...
...
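A sketch of the caching pattern get_graph implements, with illustrative names (build_graph_and_plans is hypothetical): the FADescriptor_v1 value acts as the key, so each unique problem configuration builds its cuDNN graph once.

    // inside a hypothetical get_graph(cache, descriptor):
    auto it = cache.find(descriptor);
    if (it != cache.end()) return it->second;        // reuse a previously built plan
    auto built = build_graph_and_plans(descriptor);  // check_support + build_plans
    cache.insert({descriptor, built});
    return built;

This is also why softmax_type had to be added to FADescriptor_v1 earlier in the file: without it, a vanilla-softmax graph could be wrongly reused for a sink-token problem of the same shape.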
@@ -473,6 +491,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
      variant_pack[dropout_seed] = devPtrDropoutSeed;
      variant_pack[dropout_offset] = devPtrDropoutOffset;
    }

+   if (is_softmax_offset) {
+     variant_pack[softmax_offset] = devPtrSoftmaxOffset;
+   }
+
    NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
  } catch (cudnn_frontend::cudnnException &e) {
    NVTE_ERROR(e.what());
...
...
@@ -483,14 +506,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
    int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
    int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
    float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
-   NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-   int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
-   void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
-   void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
-   void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
-   void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
-   cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
-   cudaStream_t stream, cudnnHandle_t handle) {
+   NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+   int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
+   void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
+   void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
+   void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
+   void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
+   void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
+   void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
  bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
...
...
@@ -506,6 +529,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
    is_causal = true;
    is_bottom_right = false;
  }
+ bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
  bool is_dropout = (dropout_probability != 0.0f);
  NVTE_QKV_Format q_format = nvte_get_q_format(layout);
  NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
...
...
@@ -558,11 +582,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                               layout,
                               bias_type,
                               mask_type,
+                              softmax_type,
                               window_size_left,
                               window_size_right,
                               deterministic,
-                              tensorType, tensorType};
+                              tensorType, tensorType,
+                              cudnn_frontend::DataType_t::NOT_SET,
+                              cudnn_frontend::DataType_t::NOT_SET,
+                              cudnn_frontend::DataType_t::NOT_SET};
    namespace fe = cudnn_frontend;
    using graph_and_tensors =
...
...
@@ -579,6 +606,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // dV
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // bias
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // dBias
+                          std::shared_ptr<fe::graph::Tensor_attributes>,  // softmax_offset
+                          std::shared_ptr<fe::graph::Tensor_attributes>,  // d_softmax_offset
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_q
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // seq_kv
                           std::shared_ptr<fe::graph::Tensor_attributes>,  // offset_q
...
@@ -608,7 +637,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                      .set_compute_data_type(fe::DataType_t::FLOAT);

      std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
-     std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
+     std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, softmax_offset,
+         d_softmax_offset, seq_q, seq_kv;
      std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
          offset_stats;
      std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
...
...
@@ -771,6 +801,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
        sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
      }

+     if (is_softmax_offset) {
+       softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                              .set_name("softmax_offset")
+                                              .set_dim({1, h, 1, 1})
+                                              .set_stride({h, 1, 1, 1})
+                                              .set_data_type(fe::DataType_t::FLOAT));
+       sdpa_backward_options.set_sink_token(softmax_offset);
+       d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+                                                .set_name("d_softmax_offset")
+                                                .set_dim({1, h, 1, 1})
+                                                .set_stride({h, 1, 1, 1})
+                                                .set_data_type(fe::DataType_t::FLOAT));
+       sdpa_backward_options.set_dsink_token(d_softmax_offset);
+     }
+
      auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options);

      dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride);
...
...
@@ -796,6 +841,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                     std::shared_ptr<fe::graph::Tensor_attributes>>  // dV
          key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV);
      auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
+     auto softmax_offset_tuple = is_softmax_offset
+                                     ? std::make_tuple(softmax_offset, d_softmax_offset)
+                                     : std::make_tuple(nullptr, nullptr);
      auto padding_tuple =
          is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
      auto offset_qo_tuple =
...
...
@@ -814,17 +862,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
      NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
      NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));

-     auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple,
-                                        bias_tuple, padding_tuple, offset_qo_tuple,
-                                        offset_kv_tuple, offset_s_tuple, dropout_tuple);
+     auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple,
+                                        bias_tuple, softmax_offset_tuple, padding_tuple,
+                                        offset_qo_tuple, offset_kv_tuple, offset_s_tuple,
+                                        dropout_tuple);
      cache.insert({descriptor, return_tuple});
      return return_tuple;
    };

-   auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv,
-         offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
-       get_graph(sdpa_f16_bprop_cache, descriptor);
+   auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset,
+         d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats,
+         dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
// n.b. Care should be taken to align each of the added workspace tensors to their type.
...
...
@@ -938,6 +986,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
      variant_pack[dropout_offset] = devPtrDropoutOffset;
    }

+   if (is_softmax_offset) {
+     variant_pack[softmax_offset] = devPtrSoftmaxOffset;
+     variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset;
+   }
+
    NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
  } catch (cudnn_frontend::cudnnException &e) {
    NVTE_ERROR(e.what());
...
...
@@ -949,8 +1002,9 @@ using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
    bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-   NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-   int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias,
-   Tensor *output_O,
+   NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+   int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
+   const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
...
...
@@ -977,6 +1031,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
    bias_b = input_Bias->data.shape[0];
    bias_h = input_Bias->data.shape[1];
  }
+ void *devPtrSoftmaxOffset = nullptr;
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+   devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+ }

  void *devPtrO = output_O->data.dptr;
  void *devPtrS = nullptr;
...
...
@@ -990,11 +1048,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
    max_tokens = get_max_tokens(num_tokens);
  }

+ size_t i = 0;
  if (Aux_CTX_Tensors->size == 0) {
    const auto cudnn_runtime_version = cudnnGetVersion();
-   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-     Aux_CTX_Tensors->size = 3;
-     Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_S->data.dptr = nullptr;
    if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
      output_S->data.shape = {max_tokens, num_attn_heads, 1};
...
...
@@ -1002,41 +1059,39 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
    } else {
      output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
    }
    output_S->data.dtype = DType::kFloat32;
-   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_rng_state->data.dptr = nullptr;
    output_rng_state->data.shape = {2};
    output_rng_state->data.dtype = DType::kInt64;
-   Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-   output_bias->data.dptr = nullptr;
-   output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
-   output_bias->data.dtype = QKV_type;
-   } else {
-     Aux_CTX_Tensors->size = 2;
-     Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-     output_S->data.dptr = nullptr;
-     if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-       output_S->data.shape = {max_tokens, num_attn_heads, 1};
-     } else {
-       output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
-     }
-     output_S->data.dtype = DType::kFloat32;
-     Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-     output_rng_state->data.dptr = nullptr;
-     output_rng_state->data.shape = {2};
-     output_rng_state->data.dtype = DType::kInt64;
-   }
+   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+     Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+     output_bias->data.dptr = nullptr;
+     output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
+     output_bias->data.dtype = QKV_type;
+   }
+   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+     Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+     output_softmax_offset->data.dptr = nullptr;
+     output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
+     output_softmax_offset->data.dtype = DType::kFloat32;
+   }
+   Aux_CTX_Tensors->size = i;
- } else if (Aux_CTX_Tensors->size == 2) {
-   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-   devPtrS = output_S->data.dptr;
-   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-   output_rng_state->data.dptr = rng_state->data.dptr;
- } else if (Aux_CTX_Tensors->size == 3) {
-   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-   devPtrS = output_S->data.dptr;
-   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-   output_rng_state->data.dptr = rng_state->data.dptr;
-   Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-   output_bias->data.dptr = devPtrBias;
+ } else if (Aux_CTX_Tensors->size >= 2) {
+   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   devPtrS = output_S->data.dptr;
+   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   output_rng_state->data.dptr = rng_state->data.dptr;
+   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+     Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+     output_bias->data.dptr = devPtrBias;
+   }
+   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+     Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+     output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
+   }
  } else {
    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
  }
...
...
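A note on the protocol the size checks implement (our reading of the code, not text from the diff): Aux_CTX_Tensors->size == 0 is a shape-query pass in which the function publishes shapes and dtypes with null pointers, while size >= 2 is the execution pass in which device pointers are bound. Sketched with hypothetical helpers:

    NVTETensorPack aux = {};              // size starts at 0
    call_fused_attn_fwd(..., &aux, ...);  // pass 1: shapes reported, aux.size set
    allocate_from_shapes(&aux);           // caller allocates (hypothetical)
    call_fused_attn_fwd(..., &aux, ...);  // pass 2: pointers bound, kernel runs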
@@ -1050,11 +1105,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
    fused_attn_arbitrary_seqlen_fwd_impl(
        batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
        max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training,
-       attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
-       window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
-       devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr,
-       nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type),
-       workspace->data.dptr, &workspace_size, stream, handle);
+       attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+       window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
+       devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens,
+       nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type),
+       workspace->data.dptr, &workspace_size, stream, handle);

    if (workspace_size > 0) {
      if (workspace->data.dptr == nullptr) {
...
...
@@ -1074,9 +1129,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-   NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
-   bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
-   const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
+   NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+   int64_t window_size_right, bool deterministic, const Tensor *input_QKV,
+   const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+   const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQKV,
+   Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
    const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
...
...
@@ -1122,6 +1178,12 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
  void *devPtrSoftmaxStats = nullptr;
  devPtrSoftmaxStats = output_S->data.dptr;

+ void *devPtrSoftmaxOffset = nullptr;
+ void *devPtrdSoftmaxOffset = nullptr;
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+   devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+   devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
+ }

  void *devPtrCuSeqlens = cu_seqlens->data.dptr;
  void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
...
...
@@ -1135,11 +1197,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
    fused_attn_arbitrary_seqlen_bwd_impl(
        batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
        max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout,
-       bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ,
-       devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV,
-       devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens,
-       devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type),
-       workspace->data.dptr, &workspace_size, stream, handle);
+       bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic,
+       devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset,
+       devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset,
+       devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens,
+       devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
+       &workspace_size, stream, handle);

    if (workspace_size > 0) {
      if (workspace->data.dptr == nullptr) {
...
...
@@ -1161,12 +1223,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
    float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type,
-   int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
-   const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
-   NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-   const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
-   const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
-   Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
+   NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+   const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
+   const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+   const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+   const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
+   const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
  const auto QKV_type = input_Q->data.dtype;
...
...
@@ -1192,6 +1254,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    bias_b = input_Bias->data.shape[0];
    bias_h = input_Bias->data.shape[1];
  }
+ void *devPtrSoftmaxOffset = nullptr;
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+   devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+ }

  void *devPtrO = output_O->data.dptr;
  void *devPtrS = nullptr;
...
...
@@ -1216,11 +1282,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    max_tokens_kv = get_max_tokens(num_tokens_kv);
  }

+ size_t i = 0;
  if (Aux_CTX_Tensors->size == 0) {
    const auto cudnn_runtime_version = cudnnGetVersion();
-   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-     Aux_CTX_Tensors->size = 3;
-     Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_S->data.dptr = nullptr;
    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
      output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
...
...
@@ -1228,41 +1293,39 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    } else {
      output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
    }
    output_S->data.dtype = DType::kFloat32;
-   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
    output_rng_state->data.dptr = nullptr;
    output_rng_state->data.shape = {2};
    output_rng_state->data.dtype = DType::kInt64;
-   Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-   output_bias->data.dptr = nullptr;
-   output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
-   output_bias->data.dtype = QKV_type;
-   } else {
-     Aux_CTX_Tensors->size = 2;
-     Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-     output_S->data.dptr = nullptr;
-     if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-       output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
-     } else {
-       output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
-     }
-     output_S->data.dtype = DType::kFloat32;
-     Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-     output_rng_state->data.dptr = nullptr;
-     output_rng_state->data.shape = {2};
-     output_rng_state->data.dtype = DType::kInt64;
-   }
+   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+     Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+     output_bias->data.dptr = nullptr;
+     output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
+     output_bias->data.dtype = QKV_type;
+   }
+   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+     Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+     output_softmax_offset->data.dptr = nullptr;
+     output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
+     output_softmax_offset->data.dtype = DType::kFloat32;
+   }
+   Aux_CTX_Tensors->size = i;
- } else if (Aux_CTX_Tensors->size == 2) {
-   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-   devPtrS = output_S->data.dptr;
-   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-   output_rng_state->data.dptr = rng_state->data.dptr;
- } else if (Aux_CTX_Tensors->size == 3) {
-   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-   devPtrS = output_S->data.dptr;
-   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-   output_rng_state->data.dptr = rng_state->data.dptr;
-   Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-   output_bias->data.dptr = devPtrBias;
+ } else if (Aux_CTX_Tensors->size >= 2) {
+   Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   devPtrS = output_S->data.dptr;
+   Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+   output_rng_state->data.dptr = rng_state->data.dptr;
+   if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+     Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+     output_bias->data.dptr = devPtrBias;
+   }
+   if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+     Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+     output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
+   }
  } else {
    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
  }
...
...
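The hunk above replaces fixed slot indices (tensors[0], tensors[1], tensors[2]) with a running index, so optional aux tensors only occupy a slot when their feature is active, and the consumer branch checks size >= 2 instead of matching exact sizes. A minimal standalone sketch of that packing scheme, with hypothetical names (AuxSlot and aux_layout are not library identifiers):

#include <vector>

struct AuxSlot { const char *name; };

// Hypothetical helper mirroring the i++ packing above: mandatory slots
// first, optional slots appended only when the feature is enabled.
std::vector<AuxSlot> aux_layout(bool has_bias, bool has_softmax_offset) {
  std::vector<AuxSlot> slots;
  slots.push_back({"softmax_stats"});  // always present, slot 0
  slots.push_back({"rng_state"});      // always present, slot 1
  if (has_bias) slots.push_back({"bias"});
  if (has_softmax_offset) slots.push_back({"softmax_offset"});
  return slots;  // Aux_CTX_Tensors->size ends up as slots.size()
}

Because the layout is now predicate-dependent, producer and consumer must walk the pack with the same predicates in the same order.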
@@ -1277,11 +1340,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
       max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
       page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
-      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
-      window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
-      devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
-      devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
-      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+      window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
+      devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
+      devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
+      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
...
...
@@ -1302,10 +1365,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
-    bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
-    const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
-    Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, bool deterministic, const Tensor *input_Q,
+    const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
+    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
+    Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
+    Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle) {
...
...
@@ -1359,6 +1423,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
   void *devPtrSoftmaxStats = nullptr;
   devPtrSoftmaxStats = output_S->data.dptr;
+  void *devPtrSoftmaxOffset = nullptr;
+  void *devPtrdSoftmaxOffset = nullptr;
+  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+    devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+    devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
+  }
   void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
   void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
...
...
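For a non-vanilla softmax the backward path resolves two extra device pointers: the learnable offset itself and a gradient buffer of the same shape. A minimal sketch of that pairing (the struct and function names are hypothetical; in the diff the sources are input_SoftmaxOffset and output_dSoftmaxOffset, and both pointers stay null for NVTE_VANILLA_SOFTMAX):

struct SoftmaxOffsetPtrs {
  void *offset = nullptr;    // forward input, shape {1, num_attn_heads, 1, 1}
  void *d_offset = nullptr;  // backward output, same shape
};

SoftmaxOffsetPtrs resolve_softmax_offset(bool is_vanilla_softmax, void *in_ptr,
                                         void *grad_ptr) {
  SoftmaxOffsetPtrs p;
  if (!is_vanilla_softmax) {
    p.offset = in_ptr;     // e.g. input_SoftmaxOffset->data.dptr
    p.d_offset = grad_ptr; // e.g. output_dSoftmaxOffset->data.dptr
  }
  return p;
}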
@@ -1374,9 +1444,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
   fused_attn_arbitrary_seqlen_bwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
       max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic,
-      devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK,
-      devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
+      qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
+      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
+      devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
+      devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
       devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
       workspace->data.dptr, &workspace_size, stream, handle);
...
...
@@ -1401,12 +1472,13 @@ void fused_attn_arbitrary_seqlen_fwd(
     size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
     size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
-    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
-    const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K,
+    const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
     const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   const auto QKV_type = input_Q->data.dtype;
...
...
@@ -1425,6 +1497,10 @@ void fused_attn_arbitrary_seqlen_fwd(
     bias_b = input_Bias->data.shape[0];
     bias_h = input_Bias->data.shape[1];
   }
+  void *devPtrSoftmaxOffset = nullptr;
+  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+    devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+  }
   void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
   void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
...
...
@@ -1446,11 +1522,10 @@ void fused_attn_arbitrary_seqlen_fwd(
     max_tokens_kv = get_max_tokens(num_tokens_kv);
   }
+  size_t i = 0;
   if (Aux_CTX_Tensors->size == 0) {
     const auto cudnn_runtime_version = cudnnGetVersion();
-    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      Aux_CTX_Tensors->size = 3;
-      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     output_S->data.dptr = nullptr;
     if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
       output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
...
...
@@ -1458,41 +1533,39 @@ void fused_attn_arbitrary_seqlen_fwd(
       output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
     }
     output_S->data.dtype = DType::kFloat32;
-    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     output_rng_state->data.dptr = nullptr;
     output_rng_state->data.shape = {2};
     output_rng_state->data.dtype = DType::kInt64;
-    Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-    output_bias->data.dptr = nullptr;
-    output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
-    output_bias->data.dtype = QKV_type;
-    } else {
-      Aux_CTX_Tensors->size = 2;
-      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-      output_S->data.dptr = nullptr;
-      if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-        output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
-      } else {
-        output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
-      }
-      output_S->data.dtype = DType::kFloat32;
-      Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-      output_rng_state->data.dptr = nullptr;
-      output_rng_state->data.shape = {2};
-      output_rng_state->data.dtype = DType::kInt64;
-    }
-  } else if (Aux_CTX_Tensors->size == 2) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
-    output_rng_state->data.dptr = rng_state->data.dptr;
-  } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+      output_bias->data.dptr = nullptr;
+      output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
+      output_bias->data.dtype = QKV_type;
+    }
+    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+      Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+      output_softmax_offset->data.dptr = nullptr;
+      output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
+      output_softmax_offset->data.dtype = DType::kFloat32;
+    }
+    Aux_CTX_Tensors->size = i;
+  } else if (Aux_CTX_Tensors->size >= 2) {
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     output_rng_state->data.dptr = rng_state->data.dptr;
-    Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
-    output_bias->data.dptr = devPtrBias;
+    if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+      output_bias->data.dptr = devPtrBias;
+    }
+    if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+      Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+      output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
+    }
   } else {
    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
  }
...
...
@@ -1507,11 +1580,11 @@ void fused_attn_arbitrary_seqlen_fwd(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk,
       head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v,
       page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
-      is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left,
-      window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO,
-      devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
-      devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
-      get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+      is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
+      window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
+      devPtrSoftmaxOffset, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
+      devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
+      devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
+      workspace->data.dptr, &workspace_size, stream, handle);
   if (workspace_size > 0) {
     if (workspace->data.dptr == nullptr) {
...
...
@@ -1532,13 +1605,14 @@ void fused_attn_arbitrary_seqlen_bwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
     size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
-    const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
-    Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic,
+    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
+    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ,
+    Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
+    Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   const auto QKV_type = input_Q->data.dtype;
   void *devPtrQ = input_Q->data.dptr;
...
...
@@ -1577,6 +1651,12 @@ void fused_attn_arbitrary_seqlen_bwd(
   void *devPtrdV = output_dV->data.dptr;
   void *devPtrSoftmaxStats = nullptr;
   devPtrSoftmaxStats = output_S->data.dptr;
+  void *devPtrSoftmaxOffset = nullptr;
+  void *devPtrdSoftmaxOffset = nullptr;
+  if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+    devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+    devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
+  }
   void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
   void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
...
...
@@ -1592,9 +1672,10 @@ void fused_attn_arbitrary_seqlen_bwd(
   fused_attn_arbitrary_seqlen_bwd_impl(
       batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk,
       head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale,
-      p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
-      deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
-      devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed,
-      devPtrDropoutOffset, devPtrCuSeqlensQ,
+      p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+      window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO,
+      devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV,
+      devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset,
+      devPtrCuSeqlensQ,
       devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
       workspace->data.dptr, &workspace_size, stream, handle);
...
...
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
View file @ 063ef88d
...
...
@@ -21,17 +21,19 @@ namespace transformer_engine {
 void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
     bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias,
-    Tensor *output_O,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
+    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
     const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
     size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
-    bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, bool deterministic, const Tensor *input_QKV,
+    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQKV,
+    Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
     const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
...
...
@@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
     size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
     float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
-    const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
+    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
     NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
     const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
     const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
-    bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
-    const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
-    Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, bool deterministic, const Tensor *input_Q,
+    const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO,
+    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
+    Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
+    Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle);
...
...
@@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd(
     size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
     size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
-    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
-    const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K,
+    const Tensor *input_V, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
     const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);

 void fused_attn_arbitrary_seqlen_bwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
     size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
-    const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
-    Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic,
+    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
+    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ,
+    Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
+    Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
     const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
     cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
}  // namespace transformer_engine
...
...
transformer_engine/common/fused_attn/fused_attn_fp8.cu
View file @ 063ef88d
...
...
@@ -1658,8 +1658,9 @@ void fused_attn_fp8_fwd_impl_v1(
     void *devPtrM, void *devPtrZInv, void *devPtrO, void *devPtrDescaleQ, void *devPtrDescaleK,
     void *devPtrDescaleV, void *devPtrDescaleS, void *devPtrScaleS, void *devPtrScaleO,
     void *devPtrAmaxO, void *devPtrAmaxS, void *devPtrcuSeqlensQ, void *devPtrcuSeqlensKV,
-    void *devPtrDropoutSeed, void *devPtrDropoutOffset,
-    cudnn_frontend::DataType_t fwd_tensor_type,
+    void *devPtrDropoutSeed, void *devPtrDropoutOffset,
+    cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type,
     void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
   bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
...
...
@@ -1672,6 +1673,13 @@ void fused_attn_fp8_fwd_impl_v1(
   auto bias_h = h;
   NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
+  bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
+                             o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
+  bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 ||
+                             o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2);
+  NVTE_CHECK(is_current_scaling || is_delayed_scaling,
+             "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or "
+             "kFloat8E5M2!");
   try {
     FADescriptor_v1 descriptor{b,
...
...
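The new flags above select the FP8 recipe purely from the O tensor's dtype: current scaling writes O in F16/BF16 with an identity scale, while delayed scaling writes FP8 O and binds a real scale_o tensor later in the graph. A minimal sketch of that classification, assuming only the cudnn_frontend enum values that appear in this diff (the helper itself is not library code):

#include <cudnn_frontend.h>

enum class Fp8ScalingMode { kCurrent, kDelayed, kInvalid };

// Map the O dtype to the scaling recipe, mirroring the checks above.
Fp8ScalingMode classify_fp8_scaling(cudnn_frontend::DataType_t o_type) {
  using DT = cudnn_frontend::DataType_t;
  if (o_type == DT::HALF || o_type == DT::BFLOAT16) return Fp8ScalingMode::kCurrent;
  if (o_type == DT::FP8_E4M3 || o_type == DT::FP8_E5M2) return Fp8ScalingMode::kDelayed;
  return Fp8ScalingMode::kInvalid;  // rejected by NVTE_CHECK in the real code
}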
@@ -1695,11 +1703,14 @@ void fused_attn_fp8_fwd_impl_v1(
                                layout,
                                bias_type,
                                mask_type,
+                               NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
                                0,
                                0,
                                true,
-                               fwd_tensor_type,
-                               fwd_tensor_type};
+                               qkv_tensor_type,
+                               o_tensor_type,
+                               cudnn_frontend::DataType_t::NOT_SET,
+                               cudnn_frontend::DataType_t::NOT_SET};
   namespace fe = cudnn_frontend;
   using graph_and_tensors =
...
...
@@ -1738,7 +1749,7 @@ void fused_attn_fp8_fwd_impl_v1(
     // otherwise, build the op_graph and the plan. Then update cache
     auto mha_graph = std::make_shared<fe::graph::Graph>();
-    mha_graph->set_io_data_type(fwd_tensor_type)
+    mha_graph->set_io_data_type(qkv_tensor_type)
         .set_intermediate_data_type(fe::DataType_t::FLOAT)
        .set_compute_data_type(fe::DataType_t::FLOAT);
...
...
@@ -1786,7 +1797,13 @@ void fused_attn_fp8_fwd_impl_v1(
     descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
     descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
     scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
+    if (is_delayed_scaling) {
       scale_o = mha_graph->tensor_like(descale_q, "Scale_O");
+    }
+    if (is_current_scaling) {
+      scale_o = mha_graph->tensor(1.0f);
+    }
     fe::graph::SDPA_fp8_attributes sdpa_options;
     sdpa_options = fe::graph::SDPA_fp8_attributes()
...
...
@@ -1838,11 +1855,12 @@ void fused_attn_fp8_fwd_impl_v1(
     std::vector<int64_t> o_stride(4);
     generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout,
                           NVTE_QKV_Matrix::NVTE_O_Matrix);
-    O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
+    O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride)
+        .set_data_type(o_tensor_type);
     amax_o->set_output(true)
         .set_dim({1, 1, 1, 1})
         .set_stride({1, 1, 1, 1})
         .set_data_type(fe::DataType_t::FLOAT);
     amax_s->set_output(true)
         .set_dim({1, 1, 1, 1})
         .set_stride({1, 1, 1, 1})
...
...
@@ -1915,13 +1933,16 @@ void fused_attn_fp8_fwd_impl_v1(
         {descale_v, devPtrDescaleV},
         {descale_s, devPtrDescaleS},
         {scale_s, devPtrScaleS},
-        {scale_o, devPtrScaleO},
         {attn_scale, &scaling_factor},
         {O, devPtrO},
         {amax_s, devPtrAmaxS},
         {amax_o, devPtrAmaxO},
         {Stats, devPtrM}};
+    if (is_delayed_scaling) {
+      variant_pack[scale_o] = devPtrScaleO;
+    }
     /* if (is_bias) {
       variant_pack[bias] = devPtrBias;
     } */
...
...
@@ -1962,8 +1983,9 @@ void fused_attn_fp8_bwd_impl_v1(
     void *devPtrScaledP, void *devPtrScaledQ, void *devPtrScaledK, void *devPtrScaledV,
     void *devPtrAmaxdP, void *devPtrAmaxdQ, void *devPtrAmaxdK, void *devPtrAmaxdV,
     void *devPtrcuSeqlensQ, void *devPtrcuSeqlensKV, void *devPtrDropoutSeed,
-    void *devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type,
-    cudnn_frontend::DataType_t bwd_tensor_type, void *workspace, size_t *workspace_size,
+    void *devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
+    cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type,
+    cudnn_frontend::DataType_t dqkv_tensor_type, void *workspace, size_t *workspace_size,
     cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
...
...
@@ -1977,6 +1999,15 @@ void fused_attn_fp8_bwd_impl_v1(
   auto bias_h = h;
   NVTE_CHECK(!is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
   NVTE_CHECK(!is_alibi, "FP8 fused attention does not support ALiBi yet!");
+  bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
+                             dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
+  bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 ||
+                             dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2);
+  NVTE_CHECK(is_current_scaling || is_delayed_scaling,
+             "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or "
+             "kFloat8E5M2!");
+  bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
+                      o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
   try {
     FADescriptor_v1 descriptor{b,
...
...
@@ -2000,11 +2031,14 @@ void fused_attn_fp8_bwd_impl_v1(
                                layout,
                                bias_type,
                                mask_type,
+                               NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
                                0,
                                0,
                                false,
-                               fwd_tensor_type,
-                               bwd_tensor_type};
+                               qkv_tensor_type,
+                               o_tensor_type,
+                               do_tensor_type,
+                               dqkv_tensor_type};
   namespace fe = cudnn_frontend;
   using graph_and_tensors =
...
...
@@ -2057,7 +2091,7 @@ void fused_attn_fp8_bwd_impl_v1(
     // otherwise, build the op_graph and the plan. Then update cache
     auto mha_graph = std::make_shared<fe::graph::Graph>();
-    mha_graph->set_io_data_type(fwd_tensor_type)
+    mha_graph->set_io_data_type(qkv_tensor_type)
         .set_intermediate_data_type(fe::DataType_t::FLOAT)
         .set_compute_data_type(fe::DataType_t::FLOAT);
...
...
@@ -2097,7 +2131,8 @@ void fused_attn_fp8_bwd_impl_v1(
     o = mha_graph->tensor(fe::graph::Tensor_attributes()
                               .set_name("O")
                               .set_dim({b, h, s_q, d})
-                              .set_stride(o_stride));
+                              .set_stride(o_stride)
+                              .set_data_type(o_tensor_type));
     dO = mha_graph->tensor(fe::graph::Tensor_attributes()
                                .set_name("dO")
                                .set_dim({b, h, s_q, d})
...
...
@@ -2123,14 +2158,26 @@ void fused_attn_fp8_bwd_impl_v1(
     descale_k = mha_graph->tensor_like(descale_q, "Descale_q");
     descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
     descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
-    descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
     descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP");
+    if (is_O_in_F16) {
+      descale_o = mha_graph->tensor(1.0f);
+    } else {
+      descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
+    }
     descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO");
     scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
     scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP");
+    if (is_delayed_scaling) {
       scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ");
       scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK");
       scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV");
+    }
+    if (is_current_scaling) {
+      scale_dQ = mha_graph->tensor(1.0f);
+      scale_dK = mha_graph->tensor(1.0f);
+      scale_dV = mha_graph->tensor(1.0f);
+    }
     fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options;
     sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes()
...
...
@@ -2212,10 +2259,10 @@ void fused_attn_fp8_bwd_impl_v1(
         .set_stride({1, 1, 1, 1})
         .set_data_type(fe::DataType_t::FLOAT);
-    dO->set_data_type(bwd_tensor_type);
-    dQ->set_data_type(bwd_tensor_type);
-    dK->set_data_type(bwd_tensor_type);
-    dV->set_data_type(bwd_tensor_type);
+    dO->set_data_type(do_tensor_type);
+    dQ->set_data_type(dqkv_tensor_type);
+    dK->set_data_type(dqkv_tensor_type);
+    dV->set_data_type(dqkv_tensor_type);
     std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>,  // q
                std::shared_ptr<fe::graph::Tensor_attributes>,  // k
...
...
@@ -2296,14 +2343,10 @@ void fused_attn_fp8_bwd_impl_v1(
         {descale_q, devPtrDescaleQ},
         {descale_k, devPtrDescaleK},
         {descale_v, devPtrDescaleV},
-        {descale_o, devPtrDescaleO},
         {descale_dO, devPtrDescaledO},
         {descale_s, devPtrDescaleS},
         {descale_dP, devPtrDescaledP},
         {scale_s, devPtrScaleS},
-        {scale_dQ, devPtrScaledQ},
-        {scale_dK, devPtrScaledK},
-        {scale_dV, devPtrScaledV},
         {scale_dP, devPtrScaledP},
         {dQ, devPtrdQ},
         {dK, devPtrdK},
...
...
@@ -2314,6 +2357,15 @@ void fused_attn_fp8_bwd_impl_v1(
         {amax_dP, devPtrAmaxdP},
     };
+    if (is_delayed_scaling) {
+      variant_pack[scale_dQ] = devPtrScaledQ;
+      variant_pack[scale_dK] = devPtrScaledK;
+      variant_pack[scale_dV] = devPtrScaledV;
+    }
+    if (!is_O_in_F16) {
+      variant_pack[descale_o] = devPtrDescaleO;
+    }
     /* if (is_bias) {
       variant_pack[bias] = devPtrBias;
       if ((bias_b == 1) && (bias_h == h)) {
...
...
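The conditional insertions above follow from how the graph was built: a node created with mha_graph->tensor(1.0f) is an inlined constant, while only tensor_like(...) nodes are real graph inputs that take an execution-time device pointer. A minimal sketch of the idiom with stand-in types (TensorNode and bind_if are not cudnn_frontend names):

#include <memory>
#include <unordered_map>

struct TensorNode {};  // stand-in for fe::graph::Tensor_attributes

using VariantPack = std::unordered_map<std::shared_ptr<TensorNode>, void *>;

// Bind a device pointer only when the node was created as a real graph
// input (tensor_like), not as an inlined constant (tensor(1.0f)).
void bind_if(VariantPack &pack, bool node_is_graph_input,
             const std::shared_ptr<TensorNode> &node, void *dev_ptr) {
  if (node_is_graph_input) pack[node] = dev_ptr;
}

Binding a pointer to a node the graph never declared (or omitting one it did declare) would fail at execution, which is why the bindings track the same is_delayed_scaling / is_O_in_F16 predicates used at graph-build time.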
@@ -2364,6 +2416,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
     cudnnHandle_t handle) {
   using namespace transformer_engine;
   const DType QKV_type = input_QKV->data.dtype;
+  const DType O_type = output_O->data.dtype;
   void *devPtrQKV = input_QKV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   size_t stride = 0;
...
...
@@ -2430,8 +2483,8 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
         attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV,
         devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
         devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens,
-        devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        workspace->data.dptr, &workspace_size, stream, handle);
+        devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+        get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_fwd_impl(
         batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale,
         p_dropout,
...
...
@@ -2465,6 +2518,7 @@ void fused_attn_fp8_bwd_qkvpacked(
     cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   const DType QKV_type = input_QKV->data.dtype;
+  const DType dO_type = input_dO->data.dtype;
   const DType dQKV_type = output_dQKV->data.dtype;
   void *devPtrQKV = input_QKV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
...
...
@@ -2482,7 +2536,11 @@ void fused_attn_fp8_bwd_qkvpacked(
   void *devPtrDescaleV = input_QKV->scale_inv.dptr;
   void *devPtrO = input_O->data.dptr;
-  void *devPtrDescaleO = input_O->scale_inv.dptr;
+  const DType O_type = input_O->data.dtype;
+  void *devPtrDescaleO = nullptr;
+  if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
+    devPtrDescaleO = input_O->scale_inv.dptr;
+  }
   void *devPtrdO = input_dO->data.dptr;
   void *devPtrDescaledO = input_dO->scale_inv.dptr;
...
...
@@ -2525,7 +2583,8 @@ void fused_attn_fp8_bwd_qkvpacked(
         devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
         devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+        get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
+        workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_bwd_impl(
         batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout,
         qkv_layout,
...
...
@@ -2563,6 +2622,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   const DType QKV_type = input_Q->data.dtype;
+  const DType O_type = output_O->data.dtype;
   void *devPtrQ = input_Q->data.dptr;
   void *devPtrKV = input_KV->data.dptr;
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
...
...
@@ -2631,8 +2691,8 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
         attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV,
         devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
         devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ,
-        devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        workspace->data.dptr, &workspace_size, stream, handle);
+        devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+        get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_fwd_impl(
         batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
...
...
@@ -2669,6 +2729,7 @@ void fused_attn_fp8_bwd_kvpacked(
     cudnnHandle_t handle) {
   using namespace transformer_engine;
   const DType QKV_type = input_Q->data.dtype;
+  const DType dO_type = input_dO->data.dtype;
   const DType dQKV_type = output_dQ->data.dtype;
   void *devPtrQ = input_Q->data.dptr;
   void *devPtrKV = input_KV->data.dptr;
...
...
@@ -2686,7 +2747,11 @@ void fused_attn_fp8_bwd_kvpacked(
   void *devPtrDescaleV = input_KV->scale_inv.dptr;
   void *devPtrO = input_O->data.dptr;
-  void *devPtrDescaleO = input_O->scale_inv.dptr;
+  const DType O_type = input_O->data.dtype;
+  void *devPtrDescaleO = nullptr;
+  if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
+    devPtrDescaleO = input_O->scale_inv.dptr;
+  }
   void *devPtrdO = input_dO->data.dptr;
   void *devPtrDescaledO = input_dO->scale_inv.dptr;
...
...
@@ -2731,7 +2796,8 @@ void fused_attn_fp8_bwd_kvpacked(
         devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
         devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+        get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
+        workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_bwd_impl(
         batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
...
...
@@ -2820,6 +2886,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
       reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(rng_state->data.dptr) + 1);
   const DType QKV_type = input_Q->data.dtype;
+  const DType O_type = output_O->data.dtype;
   size_t workspace_size = 0;
   NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
...
...
@@ -2829,8 +2896,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
         attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV,
         devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV,
         devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ,
-        devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        workspace->data.dptr, &workspace_size, stream, handle);
+        devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+        get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_fwd_impl(
         batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
...
...
@@ -2876,7 +2943,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
   void *devPtrDescaleV = input_Q->scale_inv.dptr;
   void *devPtrO = input_O->data.dptr;
-  void *devPtrDescaleO = input_O->scale_inv.dptr;
+  const DType O_type = input_O->data.dtype;
+  void *devPtrDescaleO = nullptr;
+  if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
+    devPtrDescaleO = input_O->scale_inv.dptr;
+  }
   void *devPtrdO = input_dO->data.dptr;
   void *devPtrDescaledO = input_dO->scale_inv.dptr;
...
...
@@ -2909,6 +2980,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
       reinterpret_cast<void *>(reinterpret_cast<uint64_t *>(rng_state->data.dptr) + 1);
   const DType QKV_type = input_Q->data.dtype;
+  const DType dO_type = input_dO->data.dtype;
   const DType dQKV_type = output_dQ->data.dtype;
   size_t workspace_size = 0;
...
...
@@ -2922,7 +2994,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
         devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
         devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
-        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
-        get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+        devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+        get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
+        workspace->data.dptr, &workspace_size, stream, handle);
   } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
     fused_attn::fused_attn_fp8_bwd_impl(
         batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
...
...
transformer_engine/common/fused_attn/utils.h
View file @ 063ef88d
...
...
@@ -107,23 +107,28 @@ struct FADescriptor_v1 {
   NVTE_QKV_Layout layout;
   NVTE_Bias_Type bias_type;
   NVTE_Mask_Type mask_type;
+  NVTE_Softmax_Type softmax_type;
   std::int64_t window_size_left;
   std::int64_t window_size_right;
   bool deterministic;
-  cudnn_frontend::DataType_t fwd_tensor_type;
-  cudnn_frontend::DataType_t bwd_tensor_type;
+  cudnn_frontend::DataType_t qkv_tensor_type;
+  cudnn_frontend::DataType_t o_tensor_type;
+  cudnn_frontend::DataType_t do_tensor_type;
+  cudnn_frontend::DataType_t dqkv_tensor_type;

   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                     page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
-                    attnScale, isTraining, dropoutProbability, layout, mask_type,
-                    window_size_left, window_size_right, deterministic, bias_type,
-                    fwd_tensor_type, bwd_tensor_type) <
+                    attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
+                    window_size_left, window_size_right, deterministic, bias_type,
+                    qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                     rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
-                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left,
-                    rhs.window_size_right, rhs.deterministic, rhs.bias_type,
-                    rhs.fwd_tensor_type, rhs.bwd_tensor_type);
+                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
+                    rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
+                    rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
+                    rhs.dqkv_tensor_type);
   }
 };
...
...
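FADescriptor_v1 serves as the key of the cached-plan map, and its operator< uses the std::tie idiom: tying every field into a tuple of references yields a lexicographic ordering, which is all an ordered map needs. A minimal standalone sketch of the idiom (Key and CachedPlan are illustrative names, not library types):

#include <cstdint>
#include <map>
#include <tuple>

struct CachedPlan {};

struct Key {
  std::int64_t b, h, s_q;
  int softmax_type;  // enums compare via their underlying integer values
  bool operator<(const Key &rhs) const {
    return std::tie(b, h, s_q, softmax_type) <
           std::tie(rhs.b, rhs.h, rhs.s_q, rhs.softmax_type);
  }
};

std::map<Key, CachedPlan> plan_cache;  // usable as-is once operator< exists

This is why the hunk extends both std::tie lists in lockstep with the new struct fields: any field left out of the comparison would let two distinct configurations collide on one cached plan.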
transformer_engine/common/gemm/config.cpp (new file, 0 → 100644)
View file @ 063ef88d
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "./config.h"

#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#include <cstring>

#include "../util/logging.h"

NVTEMatmulConfig nvte_create_matmul_config() { return new transformer_engine::MatmulConfig; }

void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
                                      void *buf, size_t size_in_bytes, size_t *size_written) {
  // Write attribute size
  NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
             static_cast<int>(attr), ")");
  NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
  const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
  *size_written = attr_size;

  // Return immediately if buffer is not provided
  if (buf == nullptr) {
    return;
  }

  // Check buffer size
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for matmul config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ",
             size_in_bytes, " bytes)");

  // Write to buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
  const auto &config_ = *reinterpret_cast<const transformer_engine::MatmulConfig *>(config);
  switch (attr) {
    case kNVTEMatmulConfigBiasTensor:
      std::memcpy(buf, &config_.bias_tensor, attr_size);
      break;
    case kNVTEMatmulConfigDBiasTensor:
      std::memcpy(buf, &config_.dbias_tensor, attr_size);
      break;
    case kNVTEMatmulConfigWithGELUEpilogue:
      std::memcpy(buf, &config_.with_gelu_epilogue, attr_size);
      break;
    case kNVTEMatmulConfigWithDGELUEpilogue:
      std::memcpy(buf, &config_.with_dgelu_epilogue, attr_size);
      break;
    case kNVTEMatmulConfigEpilogueAuxTensor:
      std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size);
      break;
    case kNVTEMatmulConfigUseSplitAccumulator:
      std::memcpy(buf, &config_.use_split_accumulator, attr_size);
      break;
    case kNVTEMatmulConfigSMCount:
      std::memcpy(buf, &config_.sm_count, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
                                      const void *buf, size_t size_in_bytes) {
  // Check attribute and buffer
  NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
             static_cast<int>(attr), ")");
  const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for matmul config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ",
             size_in_bytes, " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // Read from buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
  auto &config_ = *reinterpret_cast<transformer_engine::MatmulConfig *>(config);
  switch (attr) {
    case kNVTEMatmulConfigBiasTensor:
      std::memcpy(&config_.bias_tensor, buf, attr_size);
      break;
    case kNVTEMatmulConfigDBiasTensor:
      std::memcpy(&config_.dbias_tensor, buf, attr_size);
      break;
    case kNVTEMatmulConfigWithGELUEpilogue:
      std::memcpy(&config_.with_gelu_epilogue, buf, attr_size);
      break;
    case kNVTEMatmulConfigWithDGELUEpilogue:
      std::memcpy(&config_.with_dgelu_epilogue, buf, attr_size);
      break;
    case kNVTEMatmulConfigEpilogueAuxTensor:
      std::memcpy(&config_.epilogue_aux_tensor, buf, attr_size);
      break;
    case kNVTEMatmulConfigUseSplitAccumulator:
      std::memcpy(&config_.use_split_accumulator, buf, attr_size);
      break;
    case kNVTEMatmulConfigSMCount:
      std::memcpy(&config_.sm_count, buf, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

void nvte_destroy_matmul_config(NVTEMatmulConfig config) {
  if (config != nullptr) {
    delete reinterpret_cast<transformer_engine::MatmulConfig *>(config);
  }
}
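The new file defines an opaque create/set/get/destroy C API around MatmulConfig. A usage sketch, under stated assumptions: the function and attribute names come from this file, but the header location and the in-memory type of the SM-count attribute (taken here as int) are assumptions to verify against MatmulConfig::attr_sizes.

#include <transformer_engine/gemm.h>  // assumed to declare the nvte_*_matmul_config API

void example_matmul_config() {
  NVTEMatmulConfig cfg = nvte_create_matmul_config();

  // Assumption: the SM-count attribute is stored as an int; attr_sizes is
  // the source of truth for the required buffer size.
  int sm_count = 16;  // example value: restrict the GEMM to 16 SMs
  nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount, &sm_count,
                                   sizeof(sm_count));

  // ... pass cfg to whichever GEMM entry point accepts an NVTEMatmulConfig ...

  nvte_destroy_matmul_config(cfg);  // safe on nullptr as well, per the definition above
}

The getter doubles as a size query: calling nvte_get_matmul_config_attribute with buf == nullptr only writes the attribute size to size_written, so callers can allocate before reading.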