Unverified Commit 4f33ece4 authored by Charlene Yang, committed by GitHub

Add KV cache for paged/non-paged attention (#1355)



* add paged attention; test_kv_cache_accuracy and test_paged_attn pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unnecessary change from last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test_fused_attn pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unnecessary import in test_numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add license for test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add to L0 test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update license for test_paged_attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update kv_cache_manager license
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix build issue from previous merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: minor fix/preparation for inference/cuda graph
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, bshd/sbhd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, thd, no CG
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: non-paged, thd, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, using paged kernel
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: restructure kernels
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: paged, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: padding + BRCM
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: restructure IP, clean up
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix non-CG, fused
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix last commit
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: unfused, non-CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash-attn, non-CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash_attn_with_kvcache
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* commit two files missed by bcef6b34
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: thd_bshd_bshd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix last commit
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix 1c31b68d
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: add bshd_2sbhd, sbhd_2bshd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: some cleanup
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: all qkv_format combinations and merge CM files
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: some lint fixes
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: add docstring for IP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix sequences_pre
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: minor fixes for multi-layer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: initial multi-layer test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: switch to flash_attn_varlen_func
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix unfused for separate q/kv format
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix fused for separate q/kv formats
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash attn + TELayer + 2 layers
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: unfused + TL + 2layers
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: all modules/backend
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: minor cleanup
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: FlashAttention on Hopper with 2.7.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: FlashAttention + v3 from 39e7179
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: FlashAttention + v3 + FP8 + WIP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add backend support table
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: separate use_flash_attention_2 and _3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: tweaks to paged attn script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* WIP: enable/disable certain cases for fused attn
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: small fixes for lint and cg
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: minor fixes for attn/infer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix CP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: readd page info to FADescriptor_v1
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor tweak to test_numerics.py
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix 9.5/9.7 sq/skv + mask logic
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* clean up
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor fix for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* more minor fixes for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test page_size=1 for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix t3hd/th3d strides
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix ckpt recompute and fa3 k_scale
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* raise dynamo recompile limit for test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove thunder test from L0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix FA selection logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA3 q_descale shape
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove page_table from IP.step() returns
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FP8 FlashAttn DPA fp8_dpa tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* minor tweaks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FA3 note and L3 test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove redundant import in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adopt new FA3 APIs from FA2.7.3+/hopper for CP and non-CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* relax tols for TransformerLayers
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix merge 2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA import comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* relax tols for Ampere
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fa3 version and reduce messaging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update FA3 to its latest commit on main
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add default values to IP and assertion to graph.py
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add more comments in attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use custom_cache_manager instead of cache_manager
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 05f6a691
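
Overview: the PR replaces the old two-argument InferenceParams(max_batch_size, max_sequence_length) with a KV-cache-aware object that owns a paged or non-paged cache. A minimal construction sketch, based on the test changes further down (import path and parameter names as they appear in the diffs; the sizes are illustrative, not prescriptive):

import torch
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

B_max, S_max, H_kv, D_k = 4, 256, 12, 64
page_size = 256  # the paging arguments only matter when is_paged=True

inference_params = InferenceParams(
    max_batch_size=B_max,
    max_seqlen_kv=S_max,
    num_heads_kv=H_kv,
    head_dim_k=D_k,
    dtype=torch.bfloat16,
    is_paged=True,  # False selects the contiguous, non-paged cache
    total_num_pages=B_max * S_max // page_size,
    page_size=page_size,
)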
@@ -38,7 +38,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
 if [ "$RET" -ne 0 ]; then
 echo "Error in the following test cases:$FAILED_CASES"
 exit 1
...
@@ -17,7 +17,7 @@ if [ $sm_arch -gt 90 ]
 then
 FA_versions=(2.7.3)
 else
-FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
+FA_versions=(2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
 fi
 for fa_version in "${FA_versions[@]}"
@@ -28,10 +28,12 @@ do
 then
 pip3 install flash-attn==${fa_version}
 else
-pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
-python_path=`python3 -c "import site; print(site.getsitepackages()[0])"`
-mkdir -p $python_path/flashattn_hopper
-wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
+git clone https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+python_path=`python -c "import site; print(site.getsitepackages()[0])"`
+mkdir -p $python_path/flash_attn_3
+wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
+cd ../../
 fi
 # Run tests
...
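
For reference, a hedged sketch of how a caller might probe for the FlashAttention-3 interface this script drops into site-packages (the flash_attn_3 package name comes from the wget target above; fa3_available is a hypothetical helper, not part of the PR):

import importlib.util

def fa3_available() -> bool:
    # flash_attn_interface.py is copied into site-packages/flash_attn_3/
    return importlib.util.find_spec("flash_attn_3.flash_attn_interface") is not None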
@@ -26,6 +26,7 @@ from transformer_engine.pytorch.dot_product_attention.utils import (
     check_set_window_size,
     AttentionParams,
 )
+from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
 from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
 from transformer_engine.pytorch.constants import TE_DType
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -96,6 +97,8 @@ class ModelConfig:
         num_layers: int = 1,
         bias_shape: str = "1hss",
         window_size: Tuple[int, int] = (-1, -1),
+        total_requests: int = None,
+        max_ctx_len: int = None,
     ):
         self.batch_size = batch_size
         self.num_heads = num_heads
@@ -114,6 +117,8 @@
         self.num_layers = num_layers
         self.bias_shape = bias_shape
         self.window_size = window_size
+        self.total_requests = total_requests
+        self.max_ctx_len = max_ctx_len
 @contextmanager
@@ -136,6 +141,8 @@ def _get_attention_backends(
     deterministic: bool = False,
     fp8: bool = False,
     fp8_meta: Optional[Dict[str, Any]] = None,
+    is_training: bool = True,
+    inference_params: Optional[InferenceParams] = None,
 ) -> Tuple[List, List]:
     """Check if what attention backends support a model configuration"""
@@ -165,6 +172,7 @@ def _get_attention_backends(
     fused_attn_backends = []
     available_backends = None
+    flash_attention_backend = None
     fused_attention_backend = None
     def test():
@@ -190,10 +198,13 @@
             deterministic=deterministic,
             fp8=fp8,
             fp8_meta=fp8_meta,
+            is_training=is_training,
+            inference_params=inference_params,
         )
         (
             use_flash_attention,
             use_fused_attention,
+            flash_attention_backend,
             fused_attention_backend,
             use_unfused_attention,
             available_backends,
@@ -202,20 +213,21 @@
         # from get_attention_backend()
         _attention_backends["use_flash_attention"] = use_flash_attention
         _attention_backends["use_fused_attention"] = use_fused_attention
+        _attention_backends["flash_attention_backend"] = flash_attention_backend
         _attention_backends["fused_attention_backend"] = fused_attention_backend
         _attention_backends["use_unfused_attention"] = use_unfused_attention
         _attention_backends["backend_selection_requires_update"] = False
-        return available_backends, fused_attention_backend
+        return available_backends, flash_attention_backend, fused_attention_backend
     backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
     with logging_context():
         for i in range(3):
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
             _attention_backends["backend_selection_requires_update"] = True
-            available_backends, fused_attention_backend = test()
+            available_backends, flash_attention_backend, fused_attention_backend = test()
             if fused_attention_backend == FusedAttnBackend[backends[i]]:
                 fused_attn_backends.append(fused_attention_backend)
-    return available_backends, fused_attn_backends
+    return available_backends, flash_attention_backend, fused_attn_backends
 model_configs_base = {
@@ -267,7 +279,7 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
-    available_backends, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = _get_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
@@ -1131,7 +1143,7 @@ def test_transformer_layer(
     workspace_opt = True
     # Test backend availability
-    available_backends, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = _get_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
...
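
Callers of _get_attention_backends now unpack a three-element return. A usage sketch consistent with the call sites above (config, the keyword values, and the assumption that available_backends remains the (flash, fused, unfused) triple used elsewhere in this test file are illustrative):

available_backends, flash_attention_backend, fused_attn_backends = _get_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="thd_bshd_bshd",
    is_training=False,
    inference_params=inference_params,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends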
This diff is collapsed.
@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.
+from collections import OrderedDict
 import math
 import os
 from typing import Dict, List, Optional
@@ -59,6 +60,8 @@ torch.cuda.manual_seed(seed)
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()
+torch._dynamo.config.recompile_limit = 16
 class ModelConfig:
     def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
@@ -77,9 +80,9 @@ model_configs = {
 model_configs_inference = {
     # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
-    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
+    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
 }
-backends_inference = ["FlashAttention", "UnfusedAttention"]
+backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
 module_inference = ["TransformerLayer", "MultiheadAttention"]
 input_formats_inference = ["sbhd", "bshd"]
@@ -2037,14 +2040,25 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
 @pytest.mark.parametrize("input_format", input_formats_inference)
 @pytest.mark.parametrize("module", module_inference)
 @pytest.mark.parametrize("backend", backends_inference)
-def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
+@pytest.mark.parametrize("is_paged", [False, True])
+def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
+    reset_rng_states()
+    if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
+        pytest.skip("FusedAttention and FlashAttention do not support FP32")
+    if use_RoPE:
+        pytest.skip("KV cache does not support starting positions for RoPE")
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     elif backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+    elif backend == "UnfusedAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"
     config = model_configs_inference[model_key]
@@ -2057,7 +2071,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
     # Limits the max size of KV-cache
     B_max = B
-    S_max = S + 2
+    S_max = S
     if module == "TransformerLayer":
         model = TransformerLayer(
@@ -2087,7 +2101,17 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
         .eval()
     )
-    inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
+    inference_params = InferenceParams(
+        max_batch_size=B_max,
+        max_seqlen_kv=S_max,
+        num_heads_kv=H,
+        head_dim_k=head_size,
+        dtype=dtype,
+        is_paged=is_paged,
+        total_num_pages=int(B_max * S_max / 256),
+        page_size=256,
+    )
     rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
     input = torch.randn((S, B, D), dtype=dtype, device="cuda")
@@ -2100,22 +2124,39 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
     full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
     # Incrementaly generate outputs using KV-cache
+    step_dict = OrderedDict(zip(list(range(B)), [1] * B))
     for i in range(S):
+        inference_params.pre_step(step_dict)
         if input_format == "sbhd":
             incremental_input = input[i].view(1, B, D)
         else:
            incremental_input = input[:, i, :].view(B, 1, D)
+        seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
+        cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
+        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+        cu_seqlens_kv = cu_seqlens_q.clone()
+        mask_type = "padding"
+        kwargs = {}
+        if module == "TransformerLayer":
+            kwargs["self_attn_mask_type"] = mask_type
+        else:
+            kwargs["attn_mask_type"] = mask_type
         line_output = model(
             hidden_states=incremental_input,
             inference_params=inference_params,
             rotary_pos_emb=rotary_freqs if use_RoPE else None,
+            **kwargs,
+            max_seqlen_q=1,
+            max_seqlen_kv=S,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
        )
-        inference_params.sequence_len_offset += 1
        if input_format == "sbhd":
-            incremental_output[i] = line_output.view(B, D)
+            incremental_output[i, :, :] = line_output.view(B, D)
        else:
            incremental_output[:, i, :] = line_output.view(B, D)
...
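
The core of the updated test is a one-token-at-a-time decode loop driven by pre_step, which replaces the old manual sequence_len_offset bookkeeping. A condensed sketch of the same pattern (sbhd input only; model, input, B, S, D, and inference_params are assumed to exist as in the diff above):

from collections import OrderedDict
import torch

step_dict = OrderedDict(zip(range(B), [1] * B))  # one new token per request
for i in range(S):
    inference_params.pre_step(step_dict)  # advance per-request cache offsets
    incremental_input = input[i].view(1, B, D)
    cu_seqlens_q = torch.arange(B + 1, dtype=torch.int32, device="cuda")  # cumsum of ones
    cu_seqlens_kv = cu_seqlens_q.clone()
    line_output = model(
        hidden_states=incremental_input,
        inference_params=inference_params,
        self_attn_mask_type="padding",
        max_seqlen_q=1,
        max_seqlen_kv=S,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
    )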
@@ -37,7 +37,18 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
     case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
     case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
     case NVTE_QKV_Layout::NVTE_THD_THD_THD:
+    case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
       return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
+      return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD;
     default:
       NVTE_ERROR("qkv_layout not supported!");
   }
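
For the separate-Q/K/V layouts listed in this hunk, the grouping rule is mechanical; a Python restatement (sketch only, function name hypothetical):

def layout_group(layout: str) -> str:
    # Paged layouts form their own group; the other separate-tensor layouts,
    # including the new mixed sbhd/bshd/thd combinations, stay in HD_HD_HD.
    return "Paged_KV_HD_HD_HD" if layout.startswith("Paged_KV_") else "HD_HD_HD"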
@@ -51,12 +62,14 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
     case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
     case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
     case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
       return NVTE_QKV_Format::NVTE_SBHD;
     case NVTE_QKV_Layout::NVTE_BS3HD:
     case NVTE_QKV_Layout::NVTE_BSH3D:
     case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
     case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
     case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
       return NVTE_QKV_Format::NVTE_BSHD;
     case NVTE_QKV_Layout::NVTE_T3HD:
     case NVTE_QKV_Layout::NVTE_TH3D:
@@ -64,6 +77,56 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
     case NVTE_QKV_Layout::NVTE_THD_TH2D:
     case NVTE_QKV_Layout::NVTE_THD_THD_THD:
       return NVTE_QKV_Format::NVTE_THD;
+    case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
+      return NVTE_QKV_Format::NVTE_SBHD_2BSHD;
+    case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
+      return NVTE_QKV_Format::NVTE_BSHD_2SBHD;
+    case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
+      return NVTE_QKV_Format::NVTE_THD_2BSHD;
+    case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
+      return NVTE_QKV_Format::NVTE_THD_2SBHD;
+    default:
+      NVTE_ERROR("qkv_layout not supported!");
+  }
+}
+
+// map NVTE_QKV_Layout to NVTE_QKV_Format for Q
+NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) {
+  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+  switch (qkv_format) {
+    case NVTE_QKV_Format::NVTE_SBHD:
+    case NVTE_QKV_Format::NVTE_SBHD_2BSHD:
+      return NVTE_QKV_Format::NVTE_SBHD;
+    case NVTE_QKV_Format::NVTE_BSHD:
+    case NVTE_QKV_Format::NVTE_BSHD_2SBHD:
+      return NVTE_QKV_Format::NVTE_BSHD;
+    case NVTE_QKV_Format::NVTE_THD:
+    case NVTE_QKV_Format::NVTE_THD_2BSHD:
+    case NVTE_QKV_Format::NVTE_THD_2SBHD:
+      return NVTE_QKV_Format::NVTE_THD;
+    default:
+      NVTE_ERROR("qkv_layout not supported!");
+  }
+}
+
+// map NVTE_QKV_Layout to NVTE_QKV_Format for KV
+NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
+  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+  switch (qkv_format) {
+    case NVTE_QKV_Format::NVTE_SBHD:
+    case NVTE_QKV_Format::NVTE_BSHD_2SBHD:
+    case NVTE_QKV_Format::NVTE_THD_2SBHD:
+      return NVTE_QKV_Format::NVTE_SBHD;
+    case NVTE_QKV_Format::NVTE_BSHD:
+    case NVTE_QKV_Format::NVTE_SBHD_2BSHD:
+    case NVTE_QKV_Format::NVTE_THD_2BSHD:
+      return NVTE_QKV_Format::NVTE_BSHD;
+    case NVTE_QKV_Format::NVTE_THD:
+      return NVTE_QKV_Format::NVTE_THD;
     default:
       NVTE_ERROR("qkv_layout not supported!");
   }
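
Equivalently, in the string form used on the Python side, a combined format such as "sbhd_2bshd" splits into a Q format and a KV format; an illustrative helper (hypothetical name, same convention as the NVTE_*_2* enum values above):

def split_qkv_format(qkv_format: str):
    # "sbhd_2bshd" -> ("sbhd", "bshd"); uniform formats map to themselves
    if "_2" in qkv_format:
        q_fmt, kv_fmt = qkv_format.split("_2")
        return q_fmt, kv_fmt
    return qkv_format, qkv_format

assert split_qkv_format("thd_2bshd") == ("thd", "bshd")
assert split_qkv_format("bshd") == ("bshd", "bshd")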
@@ -81,6 +144,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   const int sm_arch_ = cuda::sm_arch(device_id);
   NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
   NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
+  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   auto cudnn_runtime_version = cudnnGetVersion();
@@ -202,11 +267,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
          max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
          bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
+        // 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right}
+        (cudnn_runtime_version >= 90500 &&
+         layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD &&
+         (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
+          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
+          (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
+           max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv)) &&
+         bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
         // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv)
         (cudnn_runtime_version >= 90600 &&
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
          max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
-         bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) &&
+         bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
+        // 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right}
+        // for any q_format/kv_format, and paged/non-paged
+        (cudnn_runtime_version >= 90700 &&
+         (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
+          attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+          ((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
+           bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
+          ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
+           max_seqlen_q <= max_seqlen_kv)))) &&
        // bias + mask combination
        (!(cudnn_runtime_version >= 8906 &&
           (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
@@ -216,7 +301,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
         (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
          ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
-          cudnn_runtime_version >= 90600))) &&
+          cudnn_runtime_version >= 90600)) ||
+        ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD ||
+          (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) ||
+          kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD ||
+          (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) &&
+         cudnn_runtime_version >= 90700)) &&
        // sliding window
        // pre-9.2: full attn, causal
        ((cudnn_runtime_version < 90200 && window_size_left == -1 &&
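
Restated as a predicate, the new paged-KV mask gating reads roughly as follows (Python sketch of the conditions above, assuming no bias and zero dropout; this is not the actual selection code):

def paged_kv_mask_ok(cudnn: int, mask: str, sq: int, skv: int) -> bool:
    if cudnn >= 90700:
        # 9.7 drops the seqlen % 64 == 0 requirement for bottom-right masks
        return mask in ("padding", "padding_causal") or (
            mask == "padding_causal_bottom_right" and sq <= skv)
    if cudnn >= 90500:
        # 9.5 introduces paged KV with padding/padding_causal; bottom-right
        # still needs multiples of 64 and sq <= skv
        return mask in ("padding", "padding_causal") or (
            mask == "padding_causal_bottom_right"
            and sq % 64 == 0 and skv % 64 == 0 and sq <= skv)
    return False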
@@ -465,22 +555,23 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
   }
 }
 // NVTE fused attention FWD with packed KV
-void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
-                                  NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-                                  const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-                                  const NVTETensor cu_seqlens_q_padded,
-                                  const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
-                                  size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
-                                  float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                                  NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                  int64_t window_size_left, int64_t window_size_right,
-                                  NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_fwd_kvpacked(
+    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
+    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
+    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
+    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
+    cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
   const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
   const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
+  const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
+  const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
   const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
   const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
   const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
@@ -505,11 +596,40 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
   }
   size_t t_q = 0;
   size_t t_kv = 0;
-  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
-  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
+  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
+  if (q_format == NVTE_QKV_Format::NVTE_THD) {
     t_q = input_Q->data.shape[0];
+  }
+  if (kv_format == NVTE_QKV_Format::NVTE_THD) {
     t_kv = input_KV->data.shape[0];
   }
+  int64_t num_pages_k = 0;
+  int64_t num_pages_v = 0;
+  int64_t page_size_k = 0;
+  int64_t page_size_v = 0;
+  int64_t max_pages_per_seq_k = 0;
+  int64_t max_pages_per_seq_v = 0;
+  if (input_page_table_k->data.dptr != nullptr) {
+    max_pages_per_seq_k = input_page_table_k->data.shape[1];
+  }
+  if (input_page_table_v->data.dptr != nullptr) {
+    max_pages_per_seq_v = input_page_table_v->data.shape[1];
+  }
+  if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) {
+    NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
+    if (kv_format == NVTE_QKV_Format::NVTE_BSHD) {
+      num_pages_k = input_KV->data.shape[0];
+      page_size_k = input_KV->data.shape[1];
+      num_pages_v = num_pages_v;
+      page_size_v = page_size_v;
+    } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) {
+      num_pages_k = input_KV->data.shape[1];
+      page_size_k = input_KV->data.shape[0];
+      num_pages_v = num_pages_v;
+      page_size_v = page_size_v;
+    }
+  }
   auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -531,11 +651,12 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8903)
     fused_attn_arbitrary_seqlen_fwd_kvpacked(
-        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout,
-        qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
-        input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
-        input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
-        handle);
+        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
+        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
+        dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
+        input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
+        input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
+        input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -596,9 +717,12 @@ void nvte_fused_attn_bwd_kvpacked(
   }
   size_t t_q = 0;
   size_t t_kv = 0;
-  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
-  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
+  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
+  if (q_format == NVTE_QKV_Format::NVTE_THD) {
     t_q = input_Q->data.shape[0];
+  }
+  if (kv_format == NVTE_QKV_Format::NVTE_THD) {
     t_kv = input_KV->data.shape[0];
   }
@@ -664,7 +788,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          const NVTETensor Bias, NVTETensor S, NVTETensor O,
                          NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
                          const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
-                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
+                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
+                         const NVTETensor page_table_v, const NVTETensor rng_state,
                          size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                          float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                          NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
@@ -676,6 +801,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
   const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
   const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
+  const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
+  const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
   const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
   const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
   const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
@@ -686,18 +813,49 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
   auto ndim = input_Q->data.shape.size();
+  auto ndim_kv = input_K->data.shape.size();
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
   size_t h_q = input_Q->data.shape[ndim - 2];
-  size_t h_kv = input_K->data.shape[ndim - 2];
+  size_t h_kv = input_K->data.shape[ndim_kv - 2];
   size_t d_qk = input_Q->data.shape[ndim - 1];
-  size_t d_v = input_V->data.shape[ndim - 1];
+  size_t d_v = input_V->data.shape[ndim_kv - 1];
   size_t t_q = 0;
   size_t t_kv = 0;
-  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
-  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
+  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
+  if (q_format == NVTE_QKV_Format::NVTE_THD) {
     t_q = input_Q->data.shape[0];
+  }
+  if (kv_format == NVTE_QKV_Format::NVTE_THD) {
     t_kv = input_K->data.shape[0];
   }
+  int64_t num_pages_k = 0;
+  int64_t num_pages_v = 0;
+  int64_t page_size_k = 0;
+  int64_t page_size_v = 0;
+  int64_t max_pages_per_seq_k = 0;
+  int64_t max_pages_per_seq_v = 0;
+  if (input_page_table_k->data.dptr != nullptr) {
+    max_pages_per_seq_k = input_page_table_k->data.shape[1];
+  }
+  if (input_page_table_v->data.dptr != nullptr) {
+    max_pages_per_seq_v = input_page_table_v->data.shape[1];
+  }
+  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
+  if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) {
+    NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
+    if (kv_format == NVTE_QKV_Format::NVTE_BSHD) {
+      num_pages_k = input_K->data.shape[0];
+      page_size_k = input_K->data.shape[1];
+      num_pages_v = input_V->data.shape[0];
+      page_size_v = input_V->data.shape[1];
+    } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) {
+      num_pages_k = input_K->data.shape[1];
+      page_size_k = input_K->data.shape[0];
+      num_pages_v = input_V->data.shape[1];
+      page_size_v = input_V->data.shape[0];
+    }
+  }
   auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -719,11 +877,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8900)
     fused_attn_arbitrary_seqlen_fwd(
-        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale,
+        b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
+        page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
         dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
         input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
-        input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
-        wkspace, stream, handle);
+        input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
+        input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -773,16 +932,20 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
   auto ndim = input_Q->data.shape.size();
+  auto ndim_kv = input_K->data.shape.size();
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
   size_t h_q = input_Q->data.shape[ndim - 2];
-  size_t h_kv = input_K->data.shape[ndim - 2];
+  size_t h_kv = input_K->data.shape[ndim_kv - 2];
   size_t d_qk = input_Q->data.shape[ndim - 1];
-  size_t d_v = input_V->data.shape[ndim - 1];
+  size_t d_v = input_V->data.shape[ndim_kv - 1];
   size_t t_q = 0;
   size_t t_kv = 0;
-  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
-  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
+  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
+  if (q_format == NVTE_QKV_Format::NVTE_THD) {
     t_q = input_Q->data.shape[0];
+  }
+  if (kv_format == NVTE_QKV_Format::NVTE_THD) {
     t_kv = input_K->data.shape[0];
   }
...
@@ -38,13 +38,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
 void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
-    bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
-    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
-    const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
-    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
-    cudaStream_t stream, cudnnHandle_t handle);
+    size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
+    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
+    float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
+    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
+    const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
+    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
+    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
+    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
 void fused_attn_arbitrary_seqlen_bwd_kvpacked(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
@@ -61,13 +63,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
 void fused_attn_arbitrary_seqlen_fwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
-    size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
+    size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
+    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
+    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
+    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
+    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
+    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
 void fused_attn_arbitrary_seqlen_bwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
...
...@@ -1679,6 +1679,12 @@ void fused_attn_fp8_fwd_impl_v1(
s_kv,
d,
d,
0,  // num_pages_k
0,  // num_pages_v
0,  // page_size_k
0,  // page_size_v
0,  // max_pages_per_seq_k
0,  // max_pages_per_seq_v (paged KV cache is not used in the FP8 path)
bias_b,
bias_h,
scaling_factor,
...@@ -1977,6 +1983,12 @@ void fused_attn_fp8_bwd_impl_v1(
s_kv,
d,
d,
0,  // num_pages_k
0,  // num_pages_v
0,  // page_size_k
0,  // page_size_v
0,  // max_pages_per_seq_k
0,  // max_pages_per_seq_v (paged KV cache is not used in the FP8 path)
bias_b,
bias_h,
scaling_factor,
...
...@@ -117,6 +117,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) ||
...@@ -223,6 +224,9 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
break;
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
...@@ -243,6 +247,52 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
strideA[hidden_transpose_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_kv * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_kv * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = b * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
}
if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) {
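The two new stride branches above encode the cross-format cases: whichever of Q/O and K/V is BSHD advances by seqlen * heads * dim per batch index, while the SBHD side advances by only heads * dim, because its leading dimension is the sequence. A small Python sketch of the same arithmetic (illustrative names, not library code) for NVTE_SBHD_BSHD_BSHD:

# Illustrative sketch of the stride logic above for NVTE_SBHD_BSHD_BSHD:
# Q/O are sbhd [s, b, h, d], K/V are bshd [b, s, h, d].
def strides_sbhd_bshd_bshd(b, h, s_q, s_kv, d, matrix):
    if matrix in ("K", "V"):  # bshd: batch is the leading dimension
        return {"batch": s_kv * h * d, "seqlen": h * d, "head": d, "hidden": 1}
    if matrix in ("Q", "O"):  # sbhd: sequence is the leading dimension
        return {"batch": h * d, "seqlen": b * h * d, "head": d, "hidden": 1}
    raise ValueError(f"unsupported matrix {matrix}")

# b=2, h=16, s_kv=128, d=64: K steps 131072 elements per batch index,
# while Q steps only 1024, since Q's batch dimension is not leading.
print(strides_sbhd_bshd_bshd(2, 16, 32, 128, 64, "K"))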
...@@ -379,30 +429,46 @@ __device__ void cu_seqlens_padded_to_offsets_impl(
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
auto cu_seqlens_id = min(tid, actual_b);
if (tid <= max_b) {
if (offsets_s != nullptr) {
offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id];
}
if (offsets_q != nullptr && offsets_o != nullptr) {
offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id];
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
break;
}
}
if (offsets_k != nullptr && offsets_v != nullptr) {
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_k[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
offsets_v[tid] = offsets_k[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
offsets_v[tid] = offsets_k[cu_seqlens_id];
break;
}
}
}
}
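In plain terms, the kernel above scales padded cumulative sequence lengths into per-tensor element offsets, now split so that Q/O offsets and K/V offsets are produced independently (mirroring the nullptr guards, since paged KV requests no K/V offsets). A hedged Python sketch for the HD_HD_HD group, with hypothetical names:

# Sketch (not library code) of cu_seqlens_padded_to_offsets for layout group
# HD_HD_HD / Paged_KV_HD_HD_HD: each offset scales the padded cumulative
# length by that tensor's per-token element count.
def ragged_offsets_hd_hd_hd(cu_q_padded, cu_kv_padded, h, hg, d_qk, d_v):
    offsets_q = [h * d_qk * c for c in cu_q_padded]
    offsets_o = [h * d_v * c for c in cu_q_padded]
    offsets_k = [hg * d_qk * c for c in cu_kv_padded]
    offsets_v = [hg * d_v * c for c in cu_kv_padded]
    return offsets_q, offsets_k, offsets_v, offsets_o

# Two sequences padded to 4 and 8 tokens, h=hg=2, d_qk=d_v=64:
print(ragged_offsets_hd_hd_hd([0, 4, 12], [0, 4, 12], 2, 2, 64, 64))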
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b,
...@@ -433,6 +499,7 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
std::array<int64_t, 4> offsets_qkvo{};
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q;
offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv;
offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv;
...
...@@ -93,6 +93,12 @@ struct FADescriptor_v1 {
std::int64_t s_kv;
std::int64_t d_qk;
std::int64_t d_v;
std::int64_t num_pages_k;
std::int64_t num_pages_v;
std::int64_t page_size_k;
std::int64_t page_size_v;
std::int64_t max_pages_per_seq_k;
std::int64_t max_pages_per_seq_v;
std::int64_t bias_b;
std::int64_t bias_h;
float attnScale;
...@@ -108,13 +114,16 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t bwd_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left,
                window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
       std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left,
                rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type,
                rhs.bwd_tensor_type);
}
};
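Because operator< now folds in the page geometry, cuDNN execution plans for paged and non-paged runs with otherwise identical shapes land in distinct cache entries. A toy Python analog of this descriptor-as-cache-key pattern (all names illustrative):

# Toy analog (not library code): a plan cache keyed on a tuple that includes
# the page geometry, mirroring the std::tie comparison above.
plan_cache = {}

def get_plan(b, h, s_q, s_kv, d, num_pages_k, page_size_k):
    key = (b, h, s_q, s_kv, d, num_pages_k, page_size_k)
    if key not in plan_cache:
        plan_cache[key] = f"plan{key}"  # stand-in for building a cuDNN graph
    return plan_cache[key]

get_plan(2, 16, 1, 4096, 64, 0, 0)     # non-paged
get_plan(2, 16, 1, 4096, 64, 256, 16)  # paged: separate entry, no collision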
...
...@@ -25,7 +25,7 @@ extern "C" {
* head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
* `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length
* or padded to the same length, and `THD`-based layouts are used when sequences have
* different lengths in a batch. `Paged_KV`-based layouts are used for paged attention.
*/
enum NVTE_QKV_Layout {
NVTE_SB3HD = 0, /*!< SB3HD layout */
...@@ -43,6 +43,16 @@ enum NVTE_QKV_Layout {
NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */
NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */
NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */
NVTE_SBHD_BSHD_BSHD = 15, /*!< SBHD_BSHD_BSHD layout */
NVTE_BSHD_SBHD_SBHD = 16, /*!< BSHD_SBHD_SBHD layout */
NVTE_THD_BSHD_BSHD = 17, /*!< THD_BSHD_BSHD layout */
NVTE_THD_SBHD_SBHD = 18, /*!< THD_SBHD_SBHD layout */
NVTE_Paged_KV_BSHD_BSHD_BSHD = 19, /*!< Paged_KV_BSHD_BSHD_BSHD layout */
NVTE_Paged_KV_BSHD_SBHD_SBHD = 20, /*!< Paged_KV_BSHD_SBHD_SBHD layout */
NVTE_Paged_KV_SBHD_BSHD_BSHD = 21, /*!< Paged_KV_SBHD_BSHD_BSHD layout */
NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */
NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */
NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */
};
/*! \enum NVTE_QKV_Layout_Group
...@@ -59,18 +69,28 @@ enum NVTE_QKV_Layout_Group {
NVTE_HD_H2D = 3,
/*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */
NVTE_HD_HD_HD = 4,
/*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */
NVTE_Paged_KV_HD_HD_HD = 5,
};
/*! \enum NVTE_QKV_Format
* \brief QKV formats
*/
enum NVTE_QKV_Format {
/*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD, Paged_KV_SBHD_SBHD_SBHD */
NVTE_SBHD = 0,
/*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD, Paged_KV_BSHD_BSHD_BSHD */
NVTE_BSHD = 1,
/*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */
NVTE_THD = 2,
/*! BSHD format for Q and SBHD format for KV, i.e. BSHD_SBHD_SBHD, Paged_KV_BSHD_SBHD_SBHD */
NVTE_BSHD_2SBHD = 3,
/*! SBHD format for Q and BSHD format for KV, i.e. SBHD_BSHD_BSHD, Paged_KV_SBHD_BSHD_BSHD */
NVTE_SBHD_2BSHD = 4,
/*! THD format for Q and BSHD format for KV, i.e. THD_BSHD_BSHD, Paged_KV_THD_BSHD_BSHD */
NVTE_THD_2BSHD = 5,
/*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */
NVTE_THD_2SBHD = 6,
};
/*! \enum NVTE_Bias_Type
...@@ -135,6 +155,22 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout);
*/
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get Q format for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
*
* \return q format, e.g. sbhd.
*/
NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get KV format for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
*
* \return kv format, e.g. bshd.
*/
NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
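The layout name itself determines what the two getters return: the first token is the Q format and the last is the KV format, with an optional Paged_KV prefix. A hedged pure-Python sketch of that naming convention (not the C API itself):

# Sketch (not library code) of the nvte_get_q_format / nvte_get_kv_format
# semantics for the separate-format layout names added above.
def q_and_kv_format(qkv_layout: str):
    parts = qkv_layout.lower().removeprefix("paged_kv_").split("_")
    return parts[0], parts[-1]

assert q_and_kv_format("sbhd_bshd_bshd") == ("sbhd", "bshd")
assert q_and_kv_format("paged_kv_thd_sbhd_sbhd") == ("thd", "sbhd")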
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
...@@ -312,6 +348,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k].
* \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
*                  it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
...@@ -329,16 +367,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
    NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
...@@ -445,6 +481,8 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k].
* \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
*                  it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
...@@ -465,7 +503,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                         const NVTETensor Bias, NVTETensor S, NVTETensor O,
                         NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                         const NVTETensor page_table_v, const NVTETensor rng_state,
                         size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                         float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                         NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
...
...@@ -36,6 +36,14 @@
      .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
      .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",                                        \
             NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);                         \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
.value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) \
.value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \
.value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \
.value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \
.value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \
  pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local())           \
      .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD)                                      \
      .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D)                                      \
...@@ -51,7 +59,17 @@
      .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D)                                        \
      .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)                                \
      .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D)                                \
      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD)                          \
.value("NVTE_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD) \
.value("NVTE_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD) \
.value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD) \
.value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD) \
.value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD) \
.value("NVTE_Paged_KV_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD) \
.value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \
.value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \
.value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \
.value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \
  pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \
      .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)      \
      .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
...
...@@ -129,6 +129,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
...@@ -164,15 +165,16 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
        ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
        dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
        kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
        mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    nvte_fused_attn_fwd(
        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
        kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
        dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
        window_size_right, query_workspace_tensor.data(), nullptr);
...@@ -256,6 +258,7 @@ static void FusedAttnForwardImpl(
                              backend, softmax_aux);
  /* Call the underlying NVTE API */
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
    auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
...@@ -273,9 +276,10 @@ static void FusedAttnForwardImpl(
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(),
        dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
        is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
        window_size_left, window_size_right, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
...@@ -283,13 +287,13 @@ static void FusedAttnForwardImpl(
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    nvte_fused_attn_fwd(
        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
  } else {
    NVTE_ERROR("Unsupported qkv_layout.");
  }
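Note the calling convention here: non-paged code paths still hand the widened API a 1-element int32 placeholder (dummy_page_table_tensor) rather than a null tensor, so every call site passes the same argument list. A hedged PyTorch sketch of that pattern, with an assumed helper name:

import torch

# Sketch (not library code): substitute a 1-element int32 placeholder when no
# page table exists, mirroring dummy_page_table_tensor above.
def page_table_or_dummy(page_table, device="cuda"):
    if page_table is None:
        return torch.zeros(1, dtype=torch.int32, device=device)
    return page_table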
...
...@@ -64,6 +64,16 @@ QKVLayouts = (
    "thd_t2hd",
    "thd_th2d",
    "thd_thd_thd",
"sbhd_bshd_bshd",
"bshd_sbhd_sbhd",
"thd_bshd_bshd",
"thd_sbhd_sbhd",
"paged_kv_bshd_bshd_bshd",
"paged_kv_bshd_sbhd_sbhd",
"paged_kv_sbhd_bshd_bshd",
"paged_kv_sbhd_sbhd_sbhd",
"paged_kv_thd_bshd_bshd",
"paged_kv_thd_sbhd_sbhd",
)
LayerTypes = ("encoder", "decoder")
...
...@@ -9,6 +9,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import (
    NVTE_QKV_Layout,
NVTE_QKV_Format,
    NVTE_Bias_Type,
    NVTE_Mask_Type,
    NVTE_Fused_Attn_Backend,
...@@ -31,6 +32,16 @@ TORCH_DType = {
    tex.DType.kInt32: torch.int32,
}
QKVFormat = {
"bshd": NVTE_QKV_Format.NVTE_BSHD,
"sbhd": NVTE_QKV_Format.NVTE_SBHD,
"thd": NVTE_QKV_Format.NVTE_THD,
"sbhd_2bshd": NVTE_QKV_Format.NVTE_SBHD_2BSHD,
"bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD,
"thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD,
"thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD,
}
QKVLayout = {
    "sb3hd": NVTE_QKV_Layout.NVTE_SB3HD,
    "sbh3d": NVTE_QKV_Layout.NVTE_SBH3D,
...@@ -47,6 +58,16 @@ QKVLayout = {
    "thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD,
    "thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D,
    "thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD,
"sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_SBHD_BSHD_BSHD,
"bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_BSHD_SBHD_SBHD,
"thd_bshd_bshd": NVTE_QKV_Layout.NVTE_THD_BSHD_BSHD,
"thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_THD_SBHD_SBHD,
"paged_kv_bshd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_BSHD_BSHD,
"paged_kv_bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_SBHD_SBHD,
"paged_kv_sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_BSHD_BSHD,
"paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD,
"paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD,
"paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD,
}
AttnBiasType = {
...@@ -100,6 +121,8 @@ def fused_attn_fwd(
    attn_bias: torch.Tensor = None,
    cu_seqlens_q_padded: torch.Tensor = None,
    cu_seqlens_kv_padded: torch.Tensor = None,
page_table_k: torch.Tensor = None,
page_table_v: torch.Tensor = None,
    s_quantizer: Quantizer = None,
    o_quantizer: Quantizer = None,
    attn_scale: float = None,
...@@ -148,6 +171,10 @@ def fused_attn_fwd(
        cumulative sequence offsets for Q; shape [batch_size + 1]
    cu_seqlens_kv_padded: torch.Tensor, default = None
        cumulative sequence offsets for KV; shape [batch_size + 1]
page_table_k: torch.Tensor, default = None
page table for K cache; shape [batch_size, max_pages_per_seq_k]
page_table_v: torch.Tensor, default = None
page table for V cache; shape [batch_size, max_pages_per_seq_v]
    s_quantizer: Quantizer, default = None
        Quantizer object for the intermediate value S.
    o_quantizer: Quantizer, default = None
...@@ -268,6 +295,8 @@ def fused_attn_fwd(
        fake_dtype,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
page_table_k,
page_table_v,
        s_quantizer,
        o_quantizer,
        attn_bias,
...
...@@ -14,6 +14,8 @@
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
...
...@@ -51,8 +51,9 @@ std::vector<py::object> fused_attn_fwd(
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
std::vector<py::object> fused_attn_bwd(
...@@ -69,6 +70,13 @@ std::vector<py::object> fused_attn_bwd(
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);
void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged);
/***************************************************************************************************
 * GEMM
 **************************************************************************************************/
...
...@@ -4,6 +4,7 @@
 * See LICENSE for license information.
 ************************************************************************/
#include "extensions.h"
#include "kv_cache.cuh"
#include "thd_utils.cuh" #include "thd_utils.cuh"
constexpr int block_size = 512; constexpr int block_size = 512;
...@@ -90,8 +91,9 @@ std::vector<py::object> fused_attn_fwd( ...@@ -90,8 +91,9 @@ std::vector<py::object> fused_attn_fwd(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const c10::optional<at::Tensor> cu_seqlens_q_padded, const c10::optional<at::Tensor> cu_seqlens_q_padded,
    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;
...@@ -126,6 +128,7 @@ std::vector<py::object> fused_attn_fwd(
  TensorWrapper te_Bias;
  TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
  TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
TensorWrapper te_page_table_k, te_page_table_v;
  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    // FP8
    auto h = q_shape[q_shape.size() - 2];
...@@ -170,6 +173,19 @@ std::vector<py::object> fused_attn_fwd(
        cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32);
  }
if ((page_table_k.has_value()) && (page_table_v.has_value())) {
auto page_table_k_sizes = page_table_k.value().sizes().vec();
std::vector<size_t> page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()};
auto page_table_v_sizes = page_table_v.value().sizes().vec();
std::vector<size_t> page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()};
te_page_table_k =
makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_page_table_v =
makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
}
  // extract rng seed and offset
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
      rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
...@@ -187,13 +203,13 @@ std::vector<py::object> fused_attn_fwd(
  TensorWrapper workspace;
  // populate tensors with appropriate shapes and dtypes
  nvte_fused_attn_fwd(
      te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
      &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
      te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
      te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
      workspace.data(), at::cuda::getCurrentCUDAStream());
  // allocate memory for workspace and auxiliary output tensors
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
...@@ -241,13 +257,13 @@ std::vector<py::object> fused_attn_fwd(
  }
  // execute the kernel
  nvte_fused_attn_fwd(
      te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
      &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
      te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
      te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
      workspace.data(), at::cuda::getCurrentCUDAStream());
  // destroy tensor wrappers, but not allocated memory
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
...@@ -1012,3 +1028,174 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
  return output;
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
template <typename scalar_t>
void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn::
convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d);
}
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) {
int h = tensor.size(1);
int d = tensor.size(2);
std::vector<int64_t> shape = {b, max_seq_len, h, d};
at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
if (new_tensor.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float) {
using dtype = float;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
return new_tensor;
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
template <typename scalar_t>
void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn::
convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d);
}
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
int b = tensor.size(0);
int max_seq_len = tensor.size(1);
int h = tensor.size(2);
int d = tensor.size(3);
std::vector<int64_t> shape = {t, h, d};
at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
if (tensor.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float) {
using dtype = float;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
return new_tensor;
}
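A hedged PyTorch sketch of the reference semantics of these two conversions (the CUDA kernels above do the same work in parallel): thd packs variable-length sequences back-to-back, bshd zero-pads each sequence to max_seq_len.

import torch

# Reference semantics (not the CUDA kernels) for the conversions above.
def thd_to_bshd_ref(x_thd, cu_seqlens, b, max_seq_len):
    h, d = x_thd.shape[1], x_thd.shape[2]
    out = torch.zeros(b, max_seq_len, h, d, dtype=x_thd.dtype, device=x_thd.device)
    for i in range(b):
        s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        out[i, : e - s] = x_thd[s:e]
    return out

def bshd_to_thd_ref(x_bshd, cu_seqlens, t):
    b, _, h, d = x_bshd.shape
    out = torch.zeros(t, h, d, dtype=x_bshd.dtype, device=x_bshd.device)
    for i in range(b):
        s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        out[s:e] = x_bshd[i, : e - s]
    return out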
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
 * 2. cu_new_lens and cu_cached_lens have shape [b + 1]; cu_cached_lens includes the tokens
 *    added in the current step
 * 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
 *    max_pages_per_seq = 1. The same underlying kernel is used for both non-paged and paged;
 *    set is_non_paged = True/False to indicate which one is intended.
 * 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
 *    become [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
 *    batch_indices_post can be used for monotonic indexing, i.e. [0, 1, 2, 3]. batch_indices is
 *    preserved for the next layer in the same iteration.
 * 5. Only a shared page_table for k_cache and v_cache is supported.
 * 6. Only pad_between_seqs = False is supported when qkv_format = thd, i.e. there must be no pad
 *    tokens between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
template <typename scalar_t>
void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache,
at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens,
at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr &&
v_cache.data_ptr() != nullptr) {
if (is_non_paged) {
transformer_engine::fused_attn::
reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()),
page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
}
transformer_engine::fused_attn::
copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()), page_table.data_ptr<int>(),
cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v,
b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
}
void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
int h_kv = new_k.size(-2);
int d_k = new_k.size(-1);
int d_v = new_v.size(-1);
NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() &&
new_k.scalar_type() == new_v.scalar_type() &&
new_k.scalar_type() == k_cache.scalar_type(),
"new_k, new_v, k_cache and v_cache must be of the same data type.");
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
if (k_cache.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float) {
using dtype = float;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
}
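A hedged PyTorch sketch of the paged copy the kernel above performs for bshd new_k/new_v, under the assumed cache layout [total_pages, page_size, h_kv, d]; the non-paged case follows by taking max_pages_per_seq = 1, i.e. one page of size max_seq_len per sequence:

import torch

# Reference semantics (not the CUDA kernel) for copying one step's new K tokens
# into a paged cache. cu_cached_lens already includes the new tokens (point 2).
def copy_new_k_ref(new_k, k_cache, page_table, cu_new_lens, cu_cached_lens):
    page_size = k_cache.size(1)
    for i in range(page_table.size(0)):
        new_len = int(cu_new_lens[i + 1] - cu_new_lens[i])
        start = int(cu_cached_lens[i + 1] - cu_cached_lens[i]) - new_len
        for j in range(new_len):
            tok = start + j  # absolute position of this token in sequence i
            page = int(page_table[i, tok // page_size])
            k_cache[page, tok % page_size] = new_k[i, j]

With max_pages_per_seq = 1 and page_size = max_seq_len, page_table degenerates to batch_indices.unsqueeze(1) and the same loop writes a contiguous [b, max_seq_len, h_kv, d] cache, which is exactly the non-paged special case described in the comment above.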