Unverified commit 4f33ece4, authored by Charlene Yang and committed by GitHub

Add KV cache for paged/non-paged attention (#1355)



* add paged attention; test_kv_cache_accuracy and test_paged_attn pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unnecessary change from last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test_fused_attn pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove unnecessary import in test_numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add license for test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add to L0 test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update license for test_paged_attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update kv_cache_manager license
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix build issue from previous merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: minor fix/preparation for inference/cuda graph
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, bshd/sbhd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, thd, no CG
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: non-paged, thd, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, using paged kernel
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: restructure kernels
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: paged, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: padding + BRCM
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: restructure IP, clean up
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix non-CG, fused
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix last commit
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: unfused, non-CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash-attn, non-CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash_attn_with_kvcache
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* commit two files missed by bcef6b34
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: thd_bshd_bshd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix last commit
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix 1c31b68d
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: add bshd_2sbhd, sbhd_2bshd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: some cleanup
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: all qkv_format combinations and merge CM files
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: some lint fixes
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: add docstring for IP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix sequences_pre
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: minor fixes for multi-layer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: initial multi-layer test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: switch to flash_attn_varlen_func
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix unfused for separate q/kv format
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix fused for separate q/kv formats
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash attn + TELayer + 2 layers
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: unfused + TL + 2layers
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: all modules/backend
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: minor cleanup
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: FlashAttention on Hopper with 2.7.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: FlashAttention + v3 from 39e7179
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: FlashAttention + v3 + FP8 + WIP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add backend support table
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: separate use_flash_attention_2 and _3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: tweaks to paged attn script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: enable/disable certain cases for fused attn
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: small fixes for lint and cg
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: minor fixes for attn/infer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix CP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: readd page info to FADescriptor_v1
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor tweak to test_numerics.py
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix 9.5/9.7 sq/skv + mask logic
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* clean up
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* more minor fixes for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test page_size=1 for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix t3hd/th3d strides
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix ckpt recompute and fa3 k_scale
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* raise dynamo recompile limit for test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove thunder test from L0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix FA selection logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA3 q_descale shape
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove page_table from IP.step() returns
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FP8 FlashAttn DPA fp8_dpa tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor tweaks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FA3 note and L3 test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove redundant import in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adopt new FA3 APIs from FA2.7.3+/hopper for CP and non-CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* relax tols for TransformerLayers
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix merge 2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA import comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* relax tols for Ampere
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fa3 version and reduce messaging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FA3 to its latest commit on main
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add default values to IP and assertion to graph.py
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add more comments in attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use custom_cache_manager instead of cache_manager
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 05f6a691
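For orientation, here is a minimal sketch of one paged-KV decode step, assembled from the InferenceParams API that the tests in this PR exercise (allocate_memory, pre_step, and the inference_params/cu_seqlens_* arguments to the attention forward). It is not copied from the PR; all sizes are illustrative assumptions, not requirements:

import torch
from collections import OrderedDict
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

# Illustrative sizes (assumptions): pages must cover batch * max_seqlen_kv tokens.
batch, heads, kv_heads, head_dim = 4, 16, 16, 128
max_seqlen_kv, page_size = 1024, 256

inference_params = InferenceParams(
    max_batch_size=batch,
    max_seqlen_kv=max_seqlen_kv,
    num_heads_kv=kv_heads,
    head_dim_k=head_dim,
    head_dim_v=head_dim,
    dtype=torch.bfloat16,
    is_paged=True,
    page_size=page_size,
    total_num_pages=batch * max_seqlen_kv // page_size,
    qkv_format="bshd",
)
attn = (
    DotProductAttention(
        num_attention_heads=heads,
        kv_channels=head_dim,
        num_gqa_groups=kv_heads,
        layer_number=1,
        attn_mask_type="padding",
        qkv_format="bshd",
    )
    .cuda()
    .eval()
)
inference_params.allocate_memory(1)  # reserve the K/V cache for layer 1

# One decode step: every active sequence contributes one new token.
step_dict = OrderedDict((seq_id, 1) for seq_id in range(batch))
inference_params.pre_step(step_dict)  # advance cache offsets / page table

q = torch.randn(batch, 1, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, 1, kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, 1, kv_heads, head_dim, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.arange(batch + 1, dtype=torch.int32, device="cuda")  # one token per sequence
out = attn(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=1,
    max_seqlen_kv=max_seqlen_kv,
    inference_params=inference_params,
)

With is_paged=False the same call pattern drives the non-paged cache; test_paged_attn.py below covers both paths across bshd/sbhd/thd formats, CUDA graphs, and FP8.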
@@ -38,7 +38,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
     exit 1
...
@@ -17,7 +17,7 @@ if [ $sm_arch -gt 90 ]
 then
     FA_versions=(2.7.3)
 else
-    FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
+    FA_versions=(2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
 fi
 for fa_version in "${FA_versions[@]}"
@@ -28,10 +28,12 @@ do
     then
         pip3 install flash-attn==${fa_version}
     else
-        pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
-        python_path=`python3 -c "import site; print(site.getsitepackages()[0])"`
-        mkdir -p $python_path/flashattn_hopper
-        wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
+        git clone https://github.com/Dao-AILab/flash-attention.git
+        cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+        python_path=`python -c "import site; print(site.getsitepackages()[0])"`
+        mkdir -p $python_path/flash_attn_3
+        wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
+        cd ../../
     fi
     # Run tests
...
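Presumably (inferred from the flash_attn_3 directory created above, not stated elsewhere in this diff), the FA3 interface then becomes importable under that package name, which can be probed with a sketch like this; transformer_engine's own check lives behind FlashAttentionUtils (fa_utils.v3_is_installed in the tests below):

# Hypothetical availability probe, mirroring the flash_attn_3 install layout above.
try:
    from flash_attn_3 import flash_attn_interface  # noqa: F401
    fa3_is_installed = True
except ImportError:
    fa3_is_installed = False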
@@ -26,6 +26,7 @@ from transformer_engine.pytorch.dot_product_attention.utils import (
     check_set_window_size,
     AttentionParams,
 )
+from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
 from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
 from transformer_engine.pytorch.constants import TE_DType
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -96,6 +97,8 @@ class ModelConfig:
         num_layers: int = 1,
         bias_shape: str = "1hss",
         window_size: Tuple[int, int] = (-1, -1),
+        total_requests: int = None,
+        max_ctx_len: int = None,
     ):
         self.batch_size = batch_size
         self.num_heads = num_heads
@@ -114,6 +117,8 @@
         self.num_layers = num_layers
         self.bias_shape = bias_shape
         self.window_size = window_size
+        self.total_requests = total_requests
+        self.max_ctx_len = max_ctx_len

 @contextmanager
@@ -136,6 +141,8 @@ def _get_attention_backends(
     deterministic: bool = False,
     fp8: bool = False,
     fp8_meta: Optional[Dict[str, Any]] = None,
+    is_training: bool = True,
+    inference_params: Optional[InferenceParams] = None,
 ) -> Tuple[List, List]:
     """Check if what attention backends support a model configuration"""
@@ -165,6 +172,7 @@
     fused_attn_backends = []
     available_backends = None
+    flash_attention_backend = None
     fused_attention_backend = None

     def test():
@@ -190,10 +198,13 @@
             deterministic=deterministic,
             fp8=fp8,
             fp8_meta=fp8_meta,
+            is_training=is_training,
+            inference_params=inference_params,
         )
         (
             use_flash_attention,
             use_fused_attention,
+            flash_attention_backend,
             fused_attention_backend,
             use_unfused_attention,
             available_backends,
@@ -202,20 +213,21 @@
         # from get_attention_backend()
         _attention_backends["use_flash_attention"] = use_flash_attention
         _attention_backends["use_fused_attention"] = use_fused_attention
+        _attention_backends["flash_attention_backend"] = flash_attention_backend
         _attention_backends["fused_attention_backend"] = fused_attention_backend
         _attention_backends["use_unfused_attention"] = use_unfused_attention
         _attention_backends["backend_selection_requires_update"] = False
-        return available_backends, fused_attention_backend
+        return available_backends, flash_attention_backend, fused_attention_backend

     backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
     with logging_context():
         for i in range(3):
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
             _attention_backends["backend_selection_requires_update"] = True
-            available_backends, fused_attention_backend = test()
+            available_backends, flash_attention_backend, fused_attention_backend = test()
             if fused_attention_backend == FusedAttnBackend[backends[i]]:
                 fused_attn_backends.append(fused_attention_backend)
-    return available_backends, fused_attn_backends
+    return available_backends, flash_attention_backend, fused_attn_backends

 model_configs_base = {
@@ -267,7 +279,7 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
-    available_backends, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = _get_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
@@ -1131,7 +1143,7 @@ def test_transformer_layer(
     workspace_opt = True

     # Test backend availability
-    available_backends, fused_attn_backends = _get_attention_backends(
+    available_backends, _, fused_attn_backends = _get_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections import OrderedDict
from typing import List
import os
import logging
import math
import pytest
import torch
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.transformer import (
TransformerLayer,
)
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
param_types = [torch.float16]
if is_bf16_compatible():
param_types.append(torch.bfloat16)
model_configs_infer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"infer_0": ModelConfig(
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
}
qkv_formats = ["bshd", "sbhd", "thd"]
def to_pretty_string(x: torch.Tensor):
return "[" + ",".join(["{:>3s}".format(str(i)) for i in x.tolist()]) + "]"
def round_up(a: int, b: int):
return b * math.ceil(a / b)
class Simulation:
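    """A toy serving trace: per-request context lengths are sampled from a Uniform
    distribution, generation lengths from an Exponential distribution, and arrival
    times from a Poisson process, as noted in the comments in __init__ below."""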
def __init__(
self,
total_requests: int = 10,
max_seq_len: int = 1024,
max_ctx_len: int = 128,
max_batch_size: int = 5,
poisson_rate: float = 1,
):
self.total_requests = total_requests
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
self.poisson_rate = poisson_rate
# calculate maximum context/generation length
self.max_ctx_len = max_ctx_len
self.max_gen_len = max_seq_len - self.max_ctx_len
# simulate sequence ids in monotonically increasing fashion
self.seq_ids = torch.range(0, total_requests - 1, dtype=torch.int32, device="cpu")
# simulate context lengths in Uniform distribution
self.context_lens = torch.randint(
1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu"
)
# simulate gen lengths in Exponential distribution
gen_dist = Exponential(1 / self.max_gen_len)
gen_lens = gen_dist.sample((total_requests,))
gen_lens = torch.where(gen_lens > self.max_gen_len, self.max_gen_len, gen_lens).to(
dtype=torch.int32, device="cpu"
)
self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to(dtype=torch.int32, device="cpu")
# simulate arrival times in Poisson distribution
if poisson_rate is None:
self.poisson_rate = torch.randint(1, max_batch_size, [1]).item()
interval_dist = Exponential(self.poisson_rate)
arrival_intervals = interval_dist.sample((total_requests,))
self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(
dtype=torch.int32, device="cpu"
)
self.last_arrival = self.arrival_times.max().item()
# initialize tensors
self.reset()
def reset(self):
self.t = 0
self.request_delays = torch.zeros([self.total_requests], dtype=torch.int32, device="cpu")
self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32, device="cpu")
self.serving_times = self.arrival_times
self.complete_times = self.arrival_times
# batch info at step t
self.t_seq_ids = torch.Tensor([]).to(dtype=torch.bool, device="cpu")
self.t_ctx_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu")
self.t_gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu")
self.t_total_lens = self.t_ctx_lens + self.t_gen_lens
self.t_batch_size = 0
# step info from step t-1 to t
self.step_lens = torch.Tensor([]).to(dtype=torch.int32, device="cpu")
def print_setup(self, logger):
logger.info("Simulation:")
logger.info(" {:<31s}: {}".format("total number of requests", self.total_requests))
logger.info(" {:<31s}: {}".format("max sequence length per request", self.max_seq_len))
logger.info(" {:<31s}: {}".format("max context length", self.max_ctx_len))
logger.info(" {:<31s}: {}".format("max generation length", self.max_gen_len))
logger.info(" {:<31s}: {}".format("max batch size per iteration", self.max_batch_size))
logger.info(" {:<31s}: {}".format("Poisson rate", self.poisson_rate))
logger.info(" {:<17s}: {}".format("sequence ids", to_pretty_string(self.seq_ids)))
logger.info(" {:<17s}: {}".format("arrival times", to_pretty_string(self.arrival_times)))
logger.info(" {:<17s}: {}".format("context lengths", to_pretty_string(self.context_lens)))
logger.info(" {:<17s}: {}".format("generation lengths", to_pretty_string(self.gen_lens)))
def print_step(self, logger):
logger.info(f"Step t = {self.t}:")
logger.info(" {:<15s}: {}".format("t_batch_size", self.t_batch_size))
logger.info(" {:<15s}: {}".format("t_seq_ids", self.t_seq_ids.tolist()))
logger.info(" {:<15s}: {}".format("t_ctx_lens", self.t_ctx_lens.tolist()))
logger.info(" {:<15s}: {}".format("t_gen_lens", self.t_gen_lens.tolist()))
logger.info(" {:<15s}: {}".format("t_total_lens", self.t_total_lens.tolist()))
logger.info(" {:<15s}: {}".format("step_lens", self.step_lens.tolist()))
def print_summary(self, logger):
logger.info("Summary:")
logger.info(" {:<18s}: {}".format("total steps taken", self.t))
logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times)))
logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times)))
logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens)))
logger.info(" {:<18s}: {}".format("complete_times", to_pretty_string(self.complete_times)))
def add_new_seqs(self, new_seq_ids):
# get ctx_lens for new seqs
self.t_seq_ids = torch.cat([self.t_seq_ids, new_seq_ids], dim=0)
self.t_ctx_lens = torch.cat([self.t_ctx_lens, self.context_lens[new_seq_ids]], dim=0)
gen_lens = torch.Tensor([0] * len(new_seq_ids)).to(dtype=torch.int32, device="cpu")
self.t_gen_lens = torch.cat([self.t_gen_lens, gen_lens], dim=0)
# append new seqs' ctx_lens to step_lens
self.step_lens = torch.cat([self.step_lens, self.context_lens[new_seq_ids]], dim=0)
def remove_finished(self):
# figure out which seqs have finished
finished = torch.where(self.t_gen_lens - self.gen_lens[self.t_seq_ids] < 0, False, True).to(
dtype=torch.bool, device="cpu"
)
self.t_seq_ids = self.t_seq_ids[~finished]
self.t_ctx_lens = self.t_ctx_lens[~finished]
self.t_gen_lens = self.t_gen_lens[~finished]
# add ones for unfinished seqs to step_lens
self.step_lens = torch.ones([len(self.t_seq_ids)], dtype=torch.int32, device="cpu")
def step(self, dynamic_fill: bool = True):
# remove finished seqs
if self.t != 0:
self.remove_finished()
# get allowed new seqs
arrived_seq_ids = torch.where(self.arrival_times == self.t, True, False).nonzero().view(-1)
queuing_seq_ids = torch.cat([self.delayed_seq_ids, arrived_seq_ids], dim=0)
if dynamic_fill:
allowed_num_new_seqs = self.max_batch_size - len(self.t_seq_ids)
else:
allowed_num_new_seqs = 0 if len(self.t_seq_ids) else self.max_batch_size
if len(queuing_seq_ids) > allowed_num_new_seqs:
new_seq_ids = queuing_seq_ids[:allowed_num_new_seqs]
self.delayed_seq_ids = queuing_seq_ids[allowed_num_new_seqs:]
self.request_delays[self.delayed_seq_ids.tolist()] += 1
else:
new_seq_ids = queuing_seq_ids
self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32)
# add new seqs to batch
self.add_new_seqs(new_seq_ids)
# update batch variables
self.t_batch_size = len(self.t_seq_ids)
self.t_total_lens = self.t_ctx_lens + self.t_gen_lens
def get_model(
module: torch.nn.Module,
config: ModelConfig,
dtype: torch.dtype,
backend: str = "FusedAttention",
qkv_format: str = "bshd",
num_layers: int = 1,
mode: str = "reference",
is_fp8: bool = False,
):
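    """Build num_layers TransformerLayer or DotProductAttention modules in eval mode;
    mode="reference" uses bshd inputs with a causal mask for the full-sequence pass,
    while mode="inference" uses the padding/padding_causal masks for incremental steps."""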
reset_rng_states()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, num_layers)
if mode == "reference":
attn_mask_type = "causal"
qkv_format = "bshd"
if mode == "inference":
attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding"
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=is_fp8,
fp8_mha=False,
)
if module == "TransformerLayer":
hidden_size = config.head_dim_qk * config.num_heads
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=4 * hidden_size,
num_attention_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim_qk,
self_attn_mask_type=attn_mask_type,
fuse_qkv_params=False,
params_dtype=dtype,
attn_input_format=qkv_format,
)
.cuda()
.eval()
for layer_number in range(1, num_layers + 1)
]
if module == "DotProductAttention":
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
DotProductAttention(
kv_channels=config.head_dim_qk,
num_attention_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
layer_number=layer_number,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
)
.cuda()
.eval()
for layer_number in range(1, num_layers + 1)
]
return model
def generate_args(
module: torch.nn.Module,
config: ModelConfig,
dtype: torch.dtype,
qkv_format: str = "bshd",
mode: str = "full_inputs",
):
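    """Generate test inputs: mode="full_inputs" builds reference tensors shaped
    [total_requests, max_seqlen_kv, ...]; mode="sample_args" builds batch-shaped
    warmup tensors for CUDA-graph capture."""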
# full inputs used as reference
if mode == "full_inputs":
warmup = False
shapes = []
if module == "TransformerLayer":
shapes.append(
[config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk]
)
if module == "DotProductAttention":
shapes.append(
[config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk]
)
shapes.append(
[
config.total_requests,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_qk,
]
)
shapes.append(
[
config.total_requests,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_v,
]
)
# sample args used for cuda graph warmup
elif mode == "sample_args":
warmup = True
shapes = []
if qkv_format == "bshd":
shape = [config.batch_size, config.max_ctx_len]
if qkv_format == "sbhd":
shape = [config.max_ctx_len, config.batch_size]
if qkv_format == "thd":
shape = [config.batch_size * config.max_ctx_len]
if module == "TransformerLayer":
shapes.append([*shape, config.num_heads * config.head_dim_qk])
if module == "DotProductAttention":
shapes.append([*shape, config.num_heads, config.head_dim_qk])
shapes.append([*shape, config.num_gqa_groups, config.head_dim_qk])
shapes.append([*shape, config.num_gqa_groups, config.head_dim_v])
num_tensors = len(shapes)
if warmup:
return [
torch.ones(
*shapes[i],
device="cuda",
dtype=dtype,
)
for i in range(num_tensors)
]
elif module == "TransformerLayer":
return [
0.01
* torch.randint(
-100,
100,
shapes[i],
device="cuda",
dtype=dtype,
)
for i in range(num_tensors)
]
elif module == "DotProductAttention":
return [
0.1
* torch.randn(
*shapes[i],
device="cuda",
dtype=dtype,
)
for i in range(num_tensors)
]
def get_tols(module, backend, dtype):
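    """(atol, rtol) used when comparing incremental outputs against the full-sequence reference."""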
if module == "TransformerLayer":
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
if module == "DotProductAttention":
tols = {
torch.half: (1e-3, 1e-3),
torch.bfloat16: (1e-2, 1e-3),
torch.float8_e4m3fn: (2e-2, 3e-2),
}
return tols[dtype]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", model_configs_infer.keys())
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("is_paged", [False, True])
@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"])
@pytest.mark.parametrize("is_cuda_graph", [False, True])
@pytest.mark.parametrize("is_fp8", [False, True])
def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
reset_rng_states()
logger = logging.getLogger("test_paged_attn")
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=is_fp8,
fp8_mha=False,
)
fp8_meta = {}
fp8_meta["recipe"] = fp8_recipe
config = model_configs_infer[model]
num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1
# flash-attn v2 requires page_size >= 256
if backend == "FlashAttention" and not fa_utils.v3_is_installed:
config_max_seqlen_q = config.max_seqlen_q
config_max_seqlen_kv = config.max_seqlen_kv
config.max_seqlen_q = 256
config.max_seqlen_kv = 256
# create a real-life simulation
max_batch_size = config.batch_size
page_size = None
total_num_pages = None
if is_paged:
page_size = 256 if backend == "FlashAttention" and not fa_utils.v3_is_installed else 1
config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size)
total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size)
else:
config.max_seqlen_kv = round_up(config.max_seqlen_kv, 64)
sim = Simulation(
total_requests=config.total_requests,
max_seq_len=config.max_seqlen_kv,
max_ctx_len=config.max_ctx_len,
max_batch_size=max_batch_size,
poisson_rate=2,
)
sim.print_setup(logger)
# initialize inference_params
inference_params = InferenceParams(
max_batch_size=max_batch_size,
max_seqlen_kv=config.max_seqlen_kv,
num_heads_kv=config.num_gqa_groups,
head_dim_k=config.head_dim_qk,
head_dim_v=config.head_dim_v,
dtype=dtype,
is_paged=is_paged,
page_size=page_size,
total_num_pages=total_num_pages,
max_ctx_len=config.max_ctx_len,
qkv_format=qkv_format,
)
if module == "DotProductAttention":
for layer_number in range(1, num_layers + 1):
inference_params.allocate_memory(layer_number)
# figure out supported backends
inference_params_qkv_format = "bshd"
qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
if is_paged:
qkv_layout = "paged_kv_" + qkv_layout
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=False,
is_training=False,
fp8=is_fp8,
fp8_meta=fp8_meta,
inference_params=inference_params,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if backend == "FlashAttention" and not flash_attn_supported:
pytest.skip("FlashAttention backend is not supported")
if backend == "FusedAttention" and not fused_attn_supported:
pytest.skip("FusedAttention backend is not supported")
if backend == "UnfusedAttention" and not unfused_attn_supported:
pytest.skip("UnfusedAttention backend is not supported")
os.environ["NVTE_FLASH_ATTN"] = str(int(backend == "FlashAttention"))
os.environ["NVTE_FUSED_ATTN"] = str(int(backend == "FusedAttention"))
os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention"))
if backend == "UnfusedAttention" and is_cuda_graph:
pytest.skip("CUDA graph is not supported for UnfusedAttention backend")
# TransformerLayer FP8 TN Gemm currently requires %8=0
if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention"):
pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported")
# create full model
logger.info("=== Generating all tokens at once ===")
model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="reference")
# generate data for all requests
full_inputs = generate_args(module, config, dtype, qkv_format="bshd", mode="full_inputs")
# generate reference results
if module == "DotProductAttention":
full_output = full_inputs
for m in model:
full_output = m(
*full_output if isinstance(full_output, List) else full_output,
)
if module == "TransformerLayer":
full_output = full_inputs
for m in model:
full_output = m(
full_output[0] if isinstance(full_output, List) else full_output,
)
# create inference model
logger.info("=== Generating one token at a time ===")
model = get_model(
module,
config,
dtype,
backend,
qkv_format,
num_layers,
mode="inference",
is_fp8=is_fp8,
)
# graph the model if necessary
if is_cuda_graph:
t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu")
step_lens = config.max_ctx_len * torch.ones(max_batch_size, dtype=torch.int32, device="cpu")
step_dict = OrderedDict(zip(t_seq_ids.tolist(), step_lens.tolist()))
inference_params.pre_step(step_dict)
sample_args = generate_args(
module, config, dtype, qkv_format=qkv_format, mode="sample_args"
)
sample_kwargs = {}
sample_kwargs["cu_seqlens_q"] = torch.linspace(
0,
config.batch_size * config.max_ctx_len,
steps=config.batch_size + 1,
device="cuda",
dtype=torch.int32,
)
sample_kwargs["cu_seqlens_kv"] = torch.linspace(
0,
config.batch_size * config.max_ctx_len,
steps=config.batch_size + 1,
device="cuda",
dtype=torch.int32,
)
sample_kwargs["inference_params"] = inference_params
sample_kwargs["max_seqlen_q"] = config.max_ctx_len
sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv
model = [
make_graphed_callables(
model[i],
sample_args,
num_warmup_iters=10,
fp8_enabled=is_fp8,
sample_kwargs=sample_kwargs,
fp8_recipe=fp8_recipe,
)
for i in range(num_layers)
]
sim.reset()
inference_params.reset()
step_dict = OrderedDict()
# simulate step by step
# t-1: ...
# compute for seq_ids = [0, 1, 2], ctx_lens = [5, 2, 3], gen_lens = [2, 9, 4],
# batch_size = 3, step_lens = [1, 1, 1]
# increase counter for gen_lens = [3, 10, 5]
# t: detect seq 1 is finished since expected_gen_lens = [12, 10, 15]
# add two new seqs 3 and 4, with ctx lens 10 and 11
# compute for seq_ids = [0, 2, 3, 4], ctx_lens = [5, 3, 10, 11], gen_lens = [3, 5, 0, 0],
# batch_size = 4, step_lens = [1, 1, 10, 11]
# increase counter for gen_lens = [3, 5, 1, 1]
max_tokens = config.batch_size * config.max_ctx_len
while True:
# prepare batch for the current step
dynamic_fill = True # inference_params.is_paged
sim.step(dynamic_fill=dynamic_fill)
sim.print_step(logger)
if sim.t_batch_size == 0:
# all sequences are finished
if sim.t > sim.last_arrival:
sim.serving_times = sim.arrival_times + sim.request_delays
sim.complete_times = sim.serving_times + sim.gen_lens
break
# not finished; run next iteration
else:
sim.t += 1
continue
# create incremental input
batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size
max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item()
num_tensors = len(full_inputs)
if qkv_format == "thd":
incremental_inputs = []
for i in range(num_tensors):
inp = full_inputs[i]
inc_inp = torch.Tensor().to(dtype=dtype, device="cuda")
for i, seq in enumerate(sim.t_seq_ids):
start = (sim.t_total_lens[i] - sim.step_lens[i]).item()
end = sim.t_total_lens[i].item()
inc_inp = torch.cat([inc_inp, inp[seq, start:end]], dim=0)
if is_cuda_graph:
inc_inp = torch.cat(
[
inc_inp,
torch.zeros(
max_tokens - sum(sim.step_lens),
*inp.shape[2:],
dtype=dtype,
device=inc_inp.device,
),
],
dim=0,
)
incremental_inputs.append(inc_inp)
else:
incremental_inputs = []
for i in range(num_tensors):
inp = full_inputs[i]
inc_inp = torch.zeros(
batch_size,
max_seqlen_q,
*inp.shape[2:],
dtype=dtype,
device="cuda",
)
for i, seq in enumerate(sim.t_seq_ids):
start = (sim.t_total_lens[i] - sim.step_lens[i]).item()
end = sim.t_total_lens[i].item()
inc_inp[i, : sim.step_lens[i], :] = inp[seq, start:end]
if qkv_format == "sbhd":
inc_inp = inc_inp.transpose(0, 1).contiguous()
incremental_inputs.append(inc_inp)
# run step
batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0)
cu_seqlens_kv = cu_seqlens_q.clone()
step_dict = OrderedDict(zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()))
inference_params.pre_step(step_dict)
if inference_params.is_paged:
inference_params.cache_manager.print_cache()
incremental_output = incremental_inputs
with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe):
for m in model:
incremental_output = m(
*incremental_output,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
inference_params=inference_params,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
)
incremental_output = [incremental_output]
incremental_output = incremental_output[0]
# compare results
atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
for i, seq in enumerate(sim.t_seq_ids):
token_index = sim.step_lens[i] - 1
if qkv_format == "bshd":
torch.testing.assert_close(
full_output[seq, sim.t_total_lens[i] - 1, :],
incremental_output[i, sim.step_lens[i] - 1, :],
atol=atol,
rtol=rtol,
)
if qkv_format == "sbhd":
torch.testing.assert_close(
full_output[seq, sim.t_total_lens[i] - 1, :],
incremental_output[sim.step_lens[i] - 1, i, :],
atol=atol,
rtol=rtol,
)
if qkv_format == "thd":
torch.testing.assert_close(
full_output[seq, sim.t_total_lens[i] - 1, :],
incremental_output[cu_seqlens_q[i + 1] - 1, :],
atol=atol,
rtol=rtol,
)
sim.t += 1
sim.t_gen_lens = sim.t_gen_lens + 1
# last value in complete_times should be equal to sim.t
sim.serving_times = sim.arrival_times + sim.request_delays
sim.complete_times = sim.serving_times + sim.gen_lens
sim.print_summary(logger)
if backend == "FlashAttention" and not fa_utils.v3_is_installed:
config.max_seqlen_q = config_max_seqlen_q
config.max_seqlen_kv = config_max_seqlen_kv
@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.
+from collections import OrderedDict
 import math
 import os
 from typing import Dict, List, Optional
@@ -59,6 +60,8 @@ torch.cuda.manual_seed(seed)
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()
+torch._dynamo.config.recompile_limit = 16

 class ModelConfig:
     def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
@@ -77,9 +80,9 @@ model_configs = {
 model_configs_inference = {
     # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
-    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
+    "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
 }
-backends_inference = ["FlashAttention", "UnfusedAttention"]
+backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
 module_inference = ["TransformerLayer", "MultiheadAttention"]
 input_formats_inference = ["sbhd", "bshd"]
@@ -2037,14 +2040,25 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
 @pytest.mark.parametrize("input_format", input_formats_inference)
 @pytest.mark.parametrize("module", module_inference)
 @pytest.mark.parametrize("backend", backends_inference)
-def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
+@pytest.mark.parametrize("is_paged", [False, True])
+def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
+    reset_rng_states()
+
+    if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
+        pytest.skip("FusedAttention and FlashAttention do not support FP32")
+    if use_RoPE:
+        pytest.skip("KV cache does not support starting positions for RoPE")
+
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
+    os.environ["NVTE_UNFUSED_ATTN"] = "0"
     if backend == "FlashAttention":
         os.environ["NVTE_FLASH_ATTN"] = "1"
     elif backend == "FusedAttention":
         os.environ["NVTE_FUSED_ATTN"] = "1"
+    elif backend == "UnfusedAttention":
+        os.environ["NVTE_UNFUSED_ATTN"] = "1"

     config = model_configs_inference[model_key]
@@ -2057,7 +2071,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
     # Limits the max size of KV-cache
     B_max = B
-    S_max = S + 2
+    S_max = S

     if module == "TransformerLayer":
         model = TransformerLayer(
@@ -2087,7 +2101,17 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
         .eval()
     )

-    inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
+    inference_params = InferenceParams(
+        max_batch_size=B_max,
+        max_seqlen_kv=S_max,
+        num_heads_kv=H,
+        head_dim_k=head_size,
+        dtype=dtype,
+        is_paged=is_paged,
+        total_num_pages=int(B_max * S_max / 256),
+        page_size=256,
+    )

     rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
     input = torch.randn((S, B, D), dtype=dtype, device="cuda")
@@ -2100,22 +2124,39 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
     full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)

     # Incrementaly generate outputs using KV-cache
+    step_dict = OrderedDict(zip(list(range(B)), [1] * B))
     for i in range(S):
+        inference_params.pre_step(step_dict)
         if input_format == "sbhd":
             incremental_input = input[i].view(1, B, D)
         else:
             incremental_input = input[:, i, :].view(B, 1, D)
+
+        seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
+        cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
+        cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+        cu_seqlens_kv = cu_seqlens_q.clone()
+        mask_type = "padding"
+        kwargs = {}
+        if module == "TransformerLayer":
+            kwargs["self_attn_mask_type"] = mask_type
+        else:
+            kwargs["attn_mask_type"] = mask_type
+
         line_output = model(
             hidden_states=incremental_input,
             inference_params=inference_params,
             rotary_pos_emb=rotary_freqs if use_RoPE else None,
+            **kwargs,
+            max_seqlen_q=1,
+            max_seqlen_kv=S,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
         )
-        inference_params.sequence_len_offset += 1

         if input_format == "sbhd":
-            incremental_output[i] = line_output.view(B, D)
+            incremental_output[i, :, :] = line_output.view(B, D)
         else:
             incremental_output[:, i, :] = line_output.view(B, D)
...
@@ -37,7 +37,18 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
     case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
     case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
     case NVTE_QKV_Layout::NVTE_THD_THD_THD:
+    case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
       return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
+      return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD;
     default:
       NVTE_ERROR("qkv_layout not supported!");
   }
@@ -51,12 +62,14 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
     case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
     case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
     case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
       return NVTE_QKV_Format::NVTE_SBHD;
     case NVTE_QKV_Layout::NVTE_BS3HD:
     case NVTE_QKV_Layout::NVTE_BSH3D:
     case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
     case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
     case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
       return NVTE_QKV_Format::NVTE_BSHD;
     case NVTE_QKV_Layout::NVTE_T3HD:
     case NVTE_QKV_Layout::NVTE_TH3D:
@@ -64,6 +77,56 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
     case NVTE_QKV_Layout::NVTE_THD_TH2D:
     case NVTE_QKV_Layout::NVTE_THD_THD_THD:
       return NVTE_QKV_Format::NVTE_THD;
+    case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
+      return NVTE_QKV_Format::NVTE_SBHD_2BSHD;
+    case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
+      return NVTE_QKV_Format::NVTE_BSHD_2SBHD;
+    case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
+      return NVTE_QKV_Format::NVTE_THD_2BSHD;
+    case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
+    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
+      return NVTE_QKV_Format::NVTE_THD_2SBHD;
+    default:
+      NVTE_ERROR("qkv_layout not supported!");
+  }
+}
+
+// map NVTE_QKV_Layout to NVTE_QKV_Format for Q
+NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) {
+  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+  switch (qkv_format) {
+    case NVTE_QKV_Format::NVTE_SBHD:
+    case NVTE_QKV_Format::NVTE_SBHD_2BSHD:
+      return NVTE_QKV_Format::NVTE_SBHD;
+    case NVTE_QKV_Format::NVTE_BSHD:
+    case NVTE_QKV_Format::NVTE_BSHD_2SBHD:
+      return NVTE_QKV_Format::NVTE_BSHD;
+    case NVTE_QKV_Format::NVTE_THD:
+    case NVTE_QKV_Format::NVTE_THD_2BSHD:
+    case NVTE_QKV_Format::NVTE_THD_2SBHD:
+      return NVTE_QKV_Format::NVTE_THD;
+    default:
+      NVTE_ERROR("qkv_layout not supported!");
+  }
+}
+
+// map NVTE_QKV_Layout to NVTE_QKV_Format for KV
+NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
+  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+  switch (qkv_format) {
+    case NVTE_QKV_Format::NVTE_SBHD:
+    case NVTE_QKV_Format::NVTE_BSHD_2SBHD:
+    case NVTE_QKV_Format::NVTE_THD_2SBHD:
+      return NVTE_QKV_Format::NVTE_SBHD;
+    case NVTE_QKV_Format::NVTE_BSHD:
+    case NVTE_QKV_Format::NVTE_SBHD_2BSHD:
+    case NVTE_QKV_Format::NVTE_THD_2BSHD:
+      return NVTE_QKV_Format::NVTE_BSHD;
+    case NVTE_QKV_Format::NVTE_THD:
+      return NVTE_QKV_Format::NVTE_THD;
     default:
       NVTE_ERROR("qkv_layout not supported!");
   }
@@ -81,6 +144,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   const int sm_arch_ = cuda::sm_arch(device_id);
   NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
   NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
+  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
+  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   auto cudnn_runtime_version = cudnnGetVersion();
...@@ -202,11 +267,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ...@@ -202,11 +267,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
// 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right}
(cudnn_runtime_version >= 90500 &&
layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv)) &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
// 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv)
(cudnn_runtime_version >= 90600 && (cudnn_runtime_version >= 90600 &&
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
// 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right}
// for any q_format/kv_format, and paged/non-paged
(cudnn_runtime_version >= 90700 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
max_seqlen_q <= max_seqlen_kv)))) &&
// bias + mask combination
(!(cudnn_runtime_version >= 8906 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
@@ -216,7 +301,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
(qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
cudnn_runtime_version >= 90600)) ||
((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD ||
(q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) ||
kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD ||
(kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) &&
cudnn_runtime_version >= 90700)) &&
// sliding window
// pre-9.2: full attn, causal
((cudnn_runtime_version < 90200 && window_size_left == -1 &&
@@ -465,22 +555,23 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
}
}
// NVTE fused attention FWD with packed KV
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
@@ -505,11 +596,40 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
}
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_KV->data.shape[0];
}
int64_t num_pages_k = 0;
int64_t num_pages_v = 0;
int64_t page_size_k = 0;
int64_t page_size_v = 0;
int64_t max_pages_per_seq_k = 0;
int64_t max_pages_per_seq_v = 0;
if (input_page_table_k->data.dptr != nullptr) {
max_pages_per_seq_k = input_page_table_k->data.shape[1];
}
if (input_page_table_v->data.dptr != nullptr) {
max_pages_per_seq_v = input_page_table_v->data.shape[1];
}
if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) {
if (kv_format == NVTE_QKV_Format::NVTE_BSHD) {
num_pages_k = input_KV->data.shape[0];
page_size_k = input_KV->data.shape[1];
num_pages_v = num_pages_k;
page_size_v = page_size_k;
} else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) {
num_pages_k = input_KV->data.shape[1];
page_size_k = input_KV->data.shape[0];
num_pages_v = num_pages_k;
page_size_v = page_size_k;
}
}
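// Added note (illustration): a paged K/V cache is a pool of fixed-size pages
// rather than one contiguous sequence per batch entry. With the BSHD page
// layout the cache shape is [num_pages, page_size, h_kv, d]; with SBHD it is
// [page_size, num_pages, h_kv, d], hence the swapped shape indices above.
// K and V share one packed pool here, so the V geometry mirrors K's
// (the original self-assignments `num_pages_v = num_pages_v;` would have left
// the V geometry at zero and are fixed above).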
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -531,11 +651,12 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -596,9 +717,12 @@ void nvte_fused_attn_bwd_kvpacked(
}
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_KV->data.shape[0];
}
@@ -664,7 +788,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
@@ -676,6 +801,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
@@ -686,18 +813,49 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
auto ndim = input_Q->data.shape.size();
auto ndim_kv = input_K->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim_kv - 2];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim_kv - 1];
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_K->data.shape[0];
}
int64_t num_pages_k = 0;
int64_t num_pages_v = 0;
int64_t page_size_k = 0;
int64_t page_size_v = 0;
int64_t max_pages_per_seq_k = 0;
int64_t max_pages_per_seq_v = 0;
if (input_page_table_k->data.dptr != nullptr) {
max_pages_per_seq_k = input_page_table_k->data.shape[1];
}
if (input_page_table_v->data.dptr != nullptr) {
max_pages_per_seq_v = input_page_table_v->data.shape[1];
}
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) {
if (kv_format == NVTE_QKV_Format::NVTE_BSHD) {
num_pages_k = input_K->data.shape[0];
page_size_k = input_K->data.shape[1];
num_pages_v = input_V->data.shape[0];
page_size_v = input_V->data.shape[1];
} else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) {
num_pages_k = input_K->data.shape[1];
page_size_k = input_K->data.shape[0];
num_pages_v = input_V->data.shape[1];
page_size_v = input_V->data.shape[0];
}
}
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
@@ -719,11 +877,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -773,16 +932,20 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
auto ndim = input_Q->data.shape.size();
auto ndim_kv = input_K->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim_kv - 2];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim_kv - 1];
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_K->data.shape[0];
}
...
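// Usage sketch (added for illustration; not part of the commit). A caller
// supplies page tables of INT32 page indices alongside the usual cu_seqlens;
// all handles below are placeholders prepared by the caller, and the trailing
// window/workspace/stream arguments follow the kvpacked variant above:
//
//   nvte_fused_attn_fwd(Q, K_cache, V_cache, Bias, S, O, &aux_ctx_tensors,
//                       cu_seqlens_q, cu_seqlens_kv,
//                       cu_seqlens_q_padded, cu_seqlens_kv_padded,
//                       page_table_k, page_table_v, rng_state,
//                       max_seqlen_q, max_seqlen_kv, /*is_training=*/false,
//                       attn_scale, /*dropout=*/0.0f, qkv_layout,
//                       NVTE_Bias_Type::NVTE_NO_BIAS,
//                       NVTE_Mask_Type::NVTE_PADDING_MASK,
//                       /*window_size_left=*/-1, /*window_size_right=*/-1,
//                       workspace, stream);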
@@ -50,14 +50,16 @@ namespace transformer_engine {
namespace fused_attn {
void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -66,26 +68,35 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv && !is_padding) {
is_causal = true;
is_bottom_right = false;
}
bool is_dropout = (is_training && dropout_probability != 0.0f);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion();
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
if (is_paged_kv) {
NVTE_CHECK(is_padding, "Paged attention requires padding mask!");
}
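// Added note (illustration): with a paged cache the valid token count of a
// sequence is generally smaller than max_pages_per_seq * page_size, so only a
// padding-style mask can describe the valid region of the cache.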
// keep original batch size because cu_seqlens are created with [b+1] shape
int64_t actual_b = b;
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
@@ -97,6 +108,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
s_kv,
d_qk,
d_v,
num_pages_k,
num_pages_v,
page_size_k,
page_size_v,
max_pages_per_seq_k,
max_pages_per_seq_v,
bias_b,
bias_h,
scaling_factor,
@@ -123,6 +140,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // page_table_k
std::shared_ptr<fe::graph::Tensor_attributes>, // page_table_v
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_q
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_k
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_v
@@ -151,6 +170,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> page_table_k, page_table_v;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
offset_stats;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
@@ -160,17 +180,36 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::vector<int64_t> v_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
if (is_paged_kv) {
generateMatrixStrides(num_pages_k, hg, page_size_k, page_size_v, d_qk, k_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(num_pages_v, hg, page_size_k, page_size_v, d_v, v_stride.data(),
layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
} else {
generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_V_Matrix);
}
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d_qk})
.set_stride(q_stride));
if (is_ragged_q) {
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
Q->set_ragged_offset(offset_q);
}
K = mha_graph->tensor(fe::graph::Tensor_attributes().set_name("K").set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes().set_name("V").set_stride(v_stride));
if (is_paged_kv) {
K->set_dim({num_pages_k, hg, page_size_k, d_qk});
V->set_dim({num_pages_v, hg, page_size_v, d_v});
} else if (is_ragged_kv) {
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b + 1, 1, 1, 1})
@@ -181,34 +220,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
K->set_dim({b, hg, s_kv, d_qk}).set_ragged_offset(offset_k);
V->set_dim({b, hg, s_kv, d_v}).set_ragged_offset(offset_v);
} else {
K->set_dim({b, hg, s_kv, d_qk});
V->set_dim({b, hg, s_kv, d_v});
}
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -254,6 +270,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv);
}
if (is_paged_kv) {
page_table_k =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("page_table_k")
.set_dim({b, 1, max_pages_per_seq_k, 1})
.set_stride({{max_pages_per_seq_k, max_pages_per_seq_k, 1, 1}})
.set_data_type(fe::DataType_t::INT32));
page_table_v =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("page_table_v")
.set_dim({b, 1, max_pages_per_seq_v, 1})
.set_stride({{max_pages_per_seq_v, max_pages_per_seq_v, 1, 1}})
.set_data_type(fe::DataType_t::INT32));
sdpa_options.set_paged_attention_k_table(page_table_k);
sdpa_options.set_paged_attention_v_table(page_table_v);
sdpa_options.set_paged_attention_max_seq_len_kv(static_cast<int32_t>(s_kv));
}
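// Added worked example (illustration): with page_size_k = 64 and a maximum KV
// length of 1000 tokens, a sequence needs ceil(1000 / 64) = 16 page entries,
// so the K page table holds b x max_pages_per_seq_k INT32 page indices with
// the [b, 1, max_pages_per_seq_k, 1] dimensions set above.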
if (is_dropout) {
dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
@@ -273,37 +307,27 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride);
if (is_ragged_q) {
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
O->set_ragged_offset(offset_o);
}
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
if (is_ragged_q && cudnn_runtime_version >= 90600) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Stats->set_stride({h * s_q, s_q, 1, 1});
}
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
@@ -316,9 +340,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v)
: std::make_tuple(nullptr, nullptr);
auto offset_qo_tuple =
is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr);
auto offset_kv_tuple =
is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr);
auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
@@ -330,16 +358,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple,
page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v,
offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_fprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
@@ -351,11 +379,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset;
}
}
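// Added worked example (illustration): with ragged Q and ragged KV on
// cuDNN >= 9.6, count = 2 * (1 + 1) = 4 offset buffers (Q, O, K, V) plus one
// for the softmax stats, i.e. 5 * num_bytes_per_ragged_offset -- matching the
// hard-coded 5/4 constants this computation replaces.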
if (workspace == nullptr) {
@@ -391,28 +420,49 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
variant_pack[seq_kv] = devActualSeqlenKV;
}
if (is_paged_kv) {
variant_pack[page_table_k] = devPtrPageTableK;
variant_pack[page_table_v] = devPtrPageTableV;
}
if (is_ragged_q || is_ragged_kv) {
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
void *devOffsets =
static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsQ = nullptr;
void *devOffsetsO = nullptr;
if (is_ragged_q) {
devOffsetsQ = devOffsets;
devOffsetsO = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
}
void *devOffsetsK = nullptr;
void *devOffsetsV = nullptr;
if (is_ragged_kv) {
devOffsetsK = static_cast<int8_t *>(devOffsets) +
static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
}
void *devOffsetsS = nullptr;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
devOffsetsS = static_cast<int8_t *>(devOffsets) +
(static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
num_bytes_per_ragged_offset;
}
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO, devOffsetsS);
if (is_ragged_q) {
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_o] = devOffsetsO;
}
if (is_ragged_kv) {
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
}
if (is_ragged_q && cudnn_runtime_version >= 90600) {
variant_pack[offset_stats] = devOffsetsS;
}
}
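// Added note (illustration): the scratch region written by
// cu_seqlens_padded_to_offsets is laid out as consecutive
// num_bytes_per_ragged_offset slots: [offset_q, offset_o] when Q is ragged,
// then [offset_k, offset_v] when KV is ragged, then [offset_stats] for
// ragged Q on cuDNN >= 9.6; the pointer arithmetic above indexes that layout.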
@@ -447,25 +497,37 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv && !is_padding) {
is_causal = true;
is_bottom_right = false;
}
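// Added note (illustration): when s_q == s_kv and no padding is present, the
// bottom-right causal diagonal coincides with the top-left one, so the plain
// causal mask is equivalent and simpler for the backend.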
bool is_dropout = (dropout_probability != 0.0f);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
if (is_paged_kv) {
NVTE_CHECK(is_padding, "Paged attention requires padding mask!");
}
// keep original batch size because cu_seqlens are created with [b+1] shape
int64_t actual_b = b;
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
// We choose between 32-bit and 64-bit offsets depending on need.
@@ -480,6 +542,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
s_kv,
d_qk,
d_v,
0,
0,
0,
0,
0,
0,
bias_b,
bias_h,
scaling_factor,
@@ -556,12 +624,42 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d_qk})
.set_stride(q_stride));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d_v})
.set_stride(o_stride));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d_v})
.set_stride(o_stride));
if (is_ragged_q) {
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
q->set_ragged_offset(offset_q);
o->set_ragged_offset(offset_o);
dO->set_ragged_offset(offset_o);
}
if (is_ragged_kv) {
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b + 1, 1, 1, 1})
@@ -572,77 +670,24 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
k->set_ragged_offset(offset_k);
v->set_ragged_offset(offset_v);
}
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_data_type(fe::DataType_t::FLOAT));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
stats->set_stride({h * s_q, s_q, 1, 1});
}
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -659,8 +704,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
if (is_ragged_q && cudnn_runtime_version >= 90600) {
sdpa_backward_options.set_max_total_seq_len_q(s_q);
}
if (is_ragged_kv && cudnn_runtime_version >= 90600) {
sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
}
@@ -724,23 +771,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options);
dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride);
dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride);
dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride);
if (is_ragged_q) {
dQ->set_ragged_offset(offset_q);
}
if (is_ragged_kv) {
dK->set_ragged_offset(offset_k);
dV->set_ragged_offset(offset_v);
}
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
@@ -757,9 +796,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_qo_tuple =
is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr);
auto offset_kv_tuple =
is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr);
auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
@@ -773,14 +814,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto return_tuple =
std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple,
offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv,
offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_bprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
@@ -792,11 +833,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset;
}
}
if (workspace == nullptr) {
@@ -845,28 +887,44 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
variant_pack[seq_kv] = devActualSeqlenKV;
}
if (is_ragged_q || is_ragged_kv) {
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
void *devOffsets =
static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsQ = nullptr;
void *devOffsetsO = nullptr;
if (is_ragged_q) {
devOffsetsQ = devOffsets;
devOffsetsO = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
}
void *devOffsetsK = nullptr;
void *devOffsetsV = nullptr;
if (is_ragged_kv) {
devOffsetsK = static_cast<int8_t *>(devOffsets) +
static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
}
void *devOffsetsS = nullptr;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
devOffsetsS = static_cast<int8_t *>(devOffsets) +
(static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
num_bytes_per_ragged_offset;
}
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO, devOffsetsS);
if (is_ragged_q) {
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_o] = devOffsetsO;
}
if (is_ragged_kv) {
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
}
if (is_ragged_q && cudnn_runtime_version >= 90600) {
variant_pack[offset_stats] = devOffsetsS; variant_pack[offset_stats] = devOffsetsS;
} }
} }
...@@ -987,11 +1045,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
  fused_attn_arbitrary_seqlen_fwd_impl(
      batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
      max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
      devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
      devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets,
      devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
      handle);

  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
...@@ -1095,20 +1154,23 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
    size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
    float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
    const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
  const auto QKV_type = input_Q->data.dtype;
  void *devPtrQ = input_Q->data.dptr;
  void *devPtrKV = input_KV->data.dptr;
  NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
  size_t stride = 0;
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
...@@ -1134,13 +1196,19 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
  void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
  void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
  void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
  void *devPtrPageTableK = page_table_k->data.dptr;
  void *devPtrPageTableV = page_table_v->data.dptr;

  size_t max_batch_size = 0;
  size_t max_tokens_q = 0;
  size_t max_tokens_kv = 0;
  if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
    max_batch_size = get_max_batch_size(batch);
  }
  if (q_format == NVTE_QKV_Format::NVTE_THD) {
    max_tokens_q = get_max_tokens(num_tokens_q);
  }
  if (kv_format == NVTE_QKV_Format::NVTE_THD) {
    max_tokens_kv = get_max_tokens(num_tokens_kv);
  }
...@@ -1150,7 +1218,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    Aux_CTX_Tensors->size = 3;
    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
    output_S->data.dptr = nullptr;
    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
      output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
    } else {
      output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
...@@ -1168,7 +1236,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    Aux_CTX_Tensors->size = 2;
    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
    output_S->data.dptr = nullptr;
    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
      output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
    } else {
      output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
...@@ -1203,11 +1271,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
  fused_attn_arbitrary_seqlen_fwd_impl(
      batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
      max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
      devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
      devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
      devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
      &workspace_size, stream, handle);

  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
...@@ -1266,10 +1336,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
  size_t max_batch_size = 0;
  size_t max_tokens_q = 0;
  size_t max_tokens_kv = 0;
  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
  if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
    max_batch_size = get_max_batch_size(batch);
  }
  if (q_format == NVTE_QKV_Format::NVTE_THD) {
    max_tokens_q = get_max_tokens(num_tokens_q);
  }
  if (kv_format == NVTE_QKV_Format::NVTE_THD) {
    max_tokens_kv = get_max_tokens(num_tokens_kv);
  }
...@@ -1319,17 +1394,20 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void fused_attn_arbitrary_seqlen_fwd(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
    size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
    size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
    const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
  using namespace transformer_engine;
  const auto QKV_type = input_Q->data.dtype;
  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
  void *devPtrQ = input_Q->data.dptr;
  void *devPtrK = input_K->data.dptr;
  void *devPtrV = input_V->data.dptr;
...@@ -1348,13 +1426,19 @@ void fused_attn_arbitrary_seqlen_fwd(
  void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
  void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
  void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
  void *devPtrPageTableK = page_table_k->data.dptr;
  void *devPtrPageTableV = page_table_v->data.dptr;

  size_t max_batch_size = 0;
  size_t max_tokens_q = 0;
  size_t max_tokens_kv = 0;
  if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
    max_batch_size = get_max_batch_size(batch);
  }
  if (q_format == NVTE_QKV_Format::NVTE_THD) {
    max_tokens_q = get_max_tokens(num_tokens_q);
  }
  if (kv_format == NVTE_QKV_Format::NVTE_THD) {
    max_tokens_kv = get_max_tokens(num_tokens_kv);
  }
...@@ -1364,7 +1448,7 @@ void fused_attn_arbitrary_seqlen_fwd(
    Aux_CTX_Tensors->size = 3;
    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
    output_S->data.dptr = nullptr;
    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
      output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
    } else {
      output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
...@@ -1382,7 +1466,7 @@ void fused_attn_arbitrary_seqlen_fwd(
    Aux_CTX_Tensors->size = 2;
    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
    output_S->data.dptr = nullptr;
    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
      output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
    } else {
      output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
...@@ -1417,11 +1501,13 @@ void fused_attn_arbitrary_seqlen_fwd(
  fused_attn_arbitrary_seqlen_fwd_impl(
      batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
      max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
      page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
      devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
      devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
      devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
      &workspace_size, stream, handle);

  if (workspace_size > 0) {
    if (workspace->data.dptr == nullptr) {
...@@ -1470,10 +1556,15 @@ void fused_attn_arbitrary_seqlen_bwd(
  size_t max_batch_size = 0;
  size_t max_tokens_q = 0;
  size_t max_tokens_kv = 0;
  NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
  NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
  if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
    max_batch_size = get_max_batch_size(batch);
  }
  if (q_format == NVTE_QKV_Format::NVTE_THD) {
    max_tokens_q = get_max_tokens(num_tokens_q);
  }
  if (kv_format == NVTE_QKV_Format::NVTE_THD) {
    max_tokens_kv = get_max_tokens(num_tokens_kv);
  }
......
...@@ -38,13 +38,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
    size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
    float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
    const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
...@@ -61,13 +63,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void fused_attn_arbitrary_seqlen_fwd(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
    size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
    size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
    const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
    const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......
...@@ -1679,6 +1679,12 @@ void fused_attn_fp8_fwd_impl_v1(
          s_kv,
          d,
          d,
          // the FP8 path has no paged KV cache: the six new paged-KV geometry
          // arguments (num_pages_k/v, page_size_k/v, max_pages_per_seq_k/v) are zero
          0,
          0,
          0,
          0,
          0,
          0,
          bias_b,
          bias_h,
          scaling_factor,
...@@ -1977,6 +1983,12 @@ void fused_attn_fp8_bwd_impl_v1(
          s_kv,
          d,
          d,
          // same as the FP8 forward path: paged-KV geometry is unused, pass zeros
          0,
          0,
          0,
          0,
          0,
          0,
          bias_b,
          bias_h,
          scaling_factor,
......
...@@ -117,6 +117,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
      }
      break;
    case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) ||
...@@ -223,6 +224,9 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
      break;
    case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
    case NVTE_QKV_Layout::NVTE_THD_THD_THD:
    case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
        strideA[batch_dim_idx] = s_q * h * d;
...@@ -243,6 +247,52 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
        strideA[hidden_transpose_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
    case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = s_kv * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = s_kv * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_transpose_dim_idx] = h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
        strideA[batch_dim_idx] = h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = b * h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
    case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
    case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
    case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
    case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
      if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
          (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
        strideA[batch_dim_idx] = h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = b * h * d;
        strideA[hidden_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
        strideA[batch_dim_idx] = h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_transpose_dim_idx] = b * h * d;
        strideA[hidden_transpose_dim_idx] = 1;
      } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
                 (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
        strideA[batch_dim_idx] = s_q * h * d;
        strideA[head_dim_idx] = d;
        strideA[seqlen_dim_idx] = h * d;
        strideA[hidden_dim_idx] = 1;
      }
      break;
  }

  if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) {
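As a quick sanity check of the two new mixed-format cases, a standalone sketch of the stride patterns they encode (names are illustrative, not library API): BSHD tensors are batch-major, so the batch stride is s * h * d and the seqlen stride is h * d; SBHD tensors are seqlen-major, so the batch stride is h * d and the seqlen stride is b * h * d.

#include <array>
#include <cstdint>

// Illustrative sketch: [batch, seqlen, head, hidden] strides for the two
// dense storage orders combined by the mixed Q/KV layouts above.
std::array<std::int64_t, 4> bshd_strides(std::int64_t b, std::int64_t s, std::int64_t h,
                                         std::int64_t d) {
  (void)b;
  return {s * h * d, h * d, d, 1};  // batch-major: B S H D
}
std::array<std::int64_t, 4> sbhd_strides(std::int64_t b, std::int64_t s, std::int64_t h,
                                         std::int64_t d) {
  (void)s;
  return {h * d, b * h * d, d, 1};  // seqlen-major: S B H D
}
// For NVTE_SBHD_BSHD_BSHD, Q/O follow sbhd_strides(b, s_q, h, d) while K/V follow
// bshd_strides(b, s_kv, h, d), matching the case bodies above.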
...@@ -379,28 +429,44 @@ __device__ void cu_seqlens_padded_to_offsets_impl(
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  auto cu_seqlens_id = min(tid, actual_b);
  if (tid <= max_b) {
    if (offsets_s != nullptr) {
      offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id];
    }
    if (offsets_q != nullptr && offsets_o != nullptr) {
      offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id];
      switch (layout_group) {
        case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
        case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
          offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
          break;
        case NVTE_QKV_Layout_Group::NVTE_3HD:
        case NVTE_QKV_Layout_Group::NVTE_H3D:
          offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
          break;
        case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
        case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
          offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
          break;
      }
    }
    if (offsets_k != nullptr && offsets_v != nullptr) {
      switch (layout_group) {
        case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
        case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
          offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
          offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id];
          break;
        case NVTE_QKV_Layout_Group::NVTE_3HD:
        case NVTE_QKV_Layout_Group::NVTE_H3D:
          offsets_k[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
          offsets_v[tid] = offsets_k[cu_seqlens_id];
          break;
        case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
        case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
          offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
          offsets_v[tid] = offsets_k[cu_seqlens_id];
          break;
      }
    }
  }
}
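To make the indexing concrete, a host-side worked example of the HD_HD_HD branch for the Q offsets (illustrative numbers only; whether the consumer interprets these as element or byte offsets is set elsewhere by the ragged offset dtype):

#include <cstdio>
#include <vector>

// Sketch of what the kernel computes for Q in the HD_HD_HD case, assuming
// h = 2 heads, d_qk = 4, and cumulative padded sequence lengths {0, 8, 24}
// (two sequences padded to 8 and 16 tokens).
int main() {
  const int h = 2, d_qk = 4;
  const std::vector<int> cu_seqlens_q_padded = {0, 8, 24};
  for (size_t tid = 0; tid < cu_seqlens_q_padded.size(); ++tid) {
    // offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]
    std::printf("offsets_q[%zu] = %d\n", tid, h * d_qk * cu_seqlens_q_padded[tid]);
    // prints 0, 64, 192: where each sequence's Q data starts
  }
  return 0;
}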
...@@ -433,6 +499,7 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
  std::array<int64_t, 4> offsets_qkvo{};
  switch (layout_group) {
    case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
    case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
      offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q;
      offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv;
      offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv;
......
...@@ -93,6 +93,12 @@ struct FADescriptor_v1 {
  std::int64_t s_kv;
  std::int64_t d_qk;
  std::int64_t d_v;
  std::int64_t num_pages_k;
  std::int64_t num_pages_v;
  std::int64_t page_size_k;
  std::int64_t page_size_v;
  std::int64_t max_pages_per_seq_k;
  std::int64_t max_pages_per_seq_v;
  std::int64_t bias_b;
  std::int64_t bias_h;
  float attnScale;
...@@ -108,13 +114,16 @@ struct FADescriptor_v1 {
  cudnn_frontend::DataType_t bwd_tensor_type;

  bool operator<(const FADescriptor_v1 &rhs) const {
    return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                    attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left,
                    window_size_right, deterministic, bias_type, fwd_tensor_type,
                    bwd_tensor_type) <
           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                    rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left,
                    rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type,
                    rhs.bwd_tensor_type);
  }
};
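Extending operator< keeps the descriptor usable as an ordered cache key: std::tie compares members lexicographically, so any change in the paged-KV geometry yields a distinct key. A minimal sketch of the pattern (the cache and its value type here are hypothetical stand-ins, not the library's actual cache):

#include <cstdint>
#include <map>
#include <tuple>

// Hypothetical stand-in for FADescriptor_v1, trimmed to three fields.
struct Descriptor {
  std::int64_t b, num_pages_k, page_size_k;
  bool operator<(const Descriptor &rhs) const {
    return std::tie(b, num_pages_k, page_size_k) <
           std::tie(rhs.b, rhs.num_pages_k, rhs.page_size_k);
  }
};

// A map keyed by the descriptor: changing page_size_k produces a new entry
// rather than a stale hit on a graph built for a different geometry.
std::map<Descriptor, int /* cached graph handle */> graph_cache;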
......
...@@ -25,24 +25,34 @@ extern "C" {
 * head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
 * `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length
 * or padded to the same length, and `THD`-based layouts are used when sequences have
 * different lengths in a batch. `Paged_KV`-based layouts are used for paged attention.
 */
enum NVTE_QKV_Layout {
  NVTE_SB3HD = 0,                    /*!< SB3HD layout */
  NVTE_SBH3D = 1,                    /*!< SBH3D layout */
  NVTE_SBHD_SB2HD = 2,               /*!< SBHD_SB2HD layout */
  NVTE_SBHD_SBH2D = 3,               /*!< SBHD_SBH2D layout */
  NVTE_SBHD_SBHD_SBHD = 4,           /*!< SBHD_SBHD_SBHD layout */
  NVTE_BS3HD = 5,                    /*!< BS3HD layout */
  NVTE_BSH3D = 6,                    /*!< BSH3D layout */
  NVTE_BSHD_BS2HD = 7,               /*!< BSHD_BS2HD layout */
  NVTE_BSHD_BSH2D = 8,               /*!< BSHD_BSH2D layout */
  NVTE_BSHD_BSHD_BSHD = 9,           /*!< BSHD_BSHD_BSHD layout */
  NVTE_T3HD = 10,                    /*!< T3HD layout */
  NVTE_TH3D = 11,                    /*!< TH3D layout */
  NVTE_THD_T2HD = 12,                /*!< THD_T2HD layout */
  NVTE_THD_TH2D = 13,                /*!< THD_TH2D layout */
  NVTE_THD_THD_THD = 14,             /*!< THD_THD_THD layout */
  NVTE_SBHD_BSHD_BSHD = 15,          /*!< SBHD_BSHD_BSHD layout */
  NVTE_BSHD_SBHD_SBHD = 16,          /*!< BSHD_SBHD_SBHD layout */
  NVTE_THD_BSHD_BSHD = 17,           /*!< THD_BSHD_BSHD layout */
  NVTE_THD_SBHD_SBHD = 18,           /*!< THD_SBHD_SBHD layout */
  NVTE_Paged_KV_BSHD_BSHD_BSHD = 19, /*!< Paged_KV_BSHD_BSHD_BSHD layout */
  NVTE_Paged_KV_BSHD_SBHD_SBHD = 20, /*!< Paged_KV_BSHD_SBHD_SBHD layout */
  NVTE_Paged_KV_SBHD_BSHD_BSHD = 21, /*!< Paged_KV_SBHD_BSHD_BSHD layout */
  NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */
  NVTE_Paged_KV_THD_BSHD_BSHD = 23,  /*!< Paged_KV_THD_BSHD_BSHD layout */
  NVTE_Paged_KV_THD_SBHD_SBHD = 24,  /*!< Paged_KV_THD_SBHD_SBHD layout */
};
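Each field of a combined layout name gives the storage of Q, K, and V in that order: NVTE_THD_BSHD_BSHD stores Q in THD format and K/V in BSHD format, and the Paged_KV_ prefix marks the variants whose K and V caches are addressed through page tables (see the NVTE_QKV_Format comments below for the full mapping).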
/*! \enum NVTE_QKV_Layout_Group
...@@ -59,18 +69,28 @@ enum NVTE_QKV_Layout_Group {
  NVTE_HD_H2D = 3,
  /*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */
  NVTE_HD_HD_HD = 4,
  /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */
  NVTE_Paged_KV_HD_HD_HD = 5,
};
/*! \enum NVTE_QKV_Format
 * \brief QKV formats
 */
enum NVTE_QKV_Format {
  /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD, Paged_KV_SBHD_SBHD_SBHD */
  NVTE_SBHD = 0,
  /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD, Paged_KV_BSHD_BSHD_BSHD */
  NVTE_BSHD = 1,
  /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */
  NVTE_THD = 2,
  /*! BSHD format for Q and SBHD format for KV, i.e. BSHD_SBHD_SBHD, Paged_KV_BSHD_SBHD_SBHD */
  NVTE_BSHD_2SBHD = 3,
  /*! SBHD format for Q and BSHD format for KV, i.e. SBHD_BSHD_BSHD, Paged_KV_SBHD_BSHD_BSHD */
  NVTE_SBHD_2BSHD = 4,
  /*! THD format for Q and BSHD format for KV, i.e. THD_BSHD_BSHD, Paged_KV_THD_BSHD_BSHD */
  NVTE_THD_2BSHD = 5,
  /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */
  NVTE_THD_2SBHD = 6,
};
/*! \enum NVTE_Bias_Type
...@@ -135,6 +155,22 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout);
 */
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get Q format for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
*
* \return q format, e.g. sbhd.
*/
NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get KV format for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
*
* \return kv format, e.g. bshd.
*/
NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
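A minimal usage sketch of the two new getters, relying only on the declarations and enum comments above:

#include <cassert>

// Sketch: a mixed layout reports different formats for Q and KV.
void check_mixed_layout_formats() {
  NVTE_QKV_Layout layout = NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD;
  assert(nvte_get_q_format(layout) == NVTE_QKV_Format::NVTE_THD);
  assert(nvte_get_kv_format(layout) == NVTE_QKV_Format::NVTE_BSHD);
  // The combined nvte_get_qkv_format() for this layout corresponds to the
  // NVTE_THD_2BSHD value documented in the enum above.
}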
/*! \brief Get fused attention backend based on input parameters.
 *
 * \param[in] q_dtype The data type of Tensor Q.
...@@ -312,6 +348,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
 * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
 * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k].
* \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v].
 * \param[in] rng_state Seed and offset of CUDA random number generator.
 * \param[in] max_seqlen_q Max sequence length used for computing for Q.
 *                         it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
...@@ -329,16 +367,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 * \param[in] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for this operation.
 */
void nvte_fused_attn_fwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
    NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
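A small sketch of the paged-KV bookkeeping implied by these parameters (the struct and the numbers are illustrative, not a required configuration):

#include <cstddef>

// Illustrative paged-KV geometry: a K cache of num_pages_k pages, each holding
// page_size_k tokens, serving a batch where every sequence may touch at most
// max_pages_per_seq_k pages. page_table_k is [batch_size, max_pages_per_seq_k]
// and stores the page index backing each logical KV block of each sequence.
struct PagedKVGeometry {
  size_t batch_size, num_pages_k, page_size_k, max_pages_per_seq_k;
  constexpr size_t page_table_entries() const { return batch_size * max_pages_per_seq_k; }
  constexpr size_t max_tokens_per_seq() const { return page_size_k * max_pages_per_seq_k; }
};

// Example: 4 sequences, 64-token pages, up to 32 pages (2048 tokens) per sequence.
constexpr PagedKVGeometry demo{4, 128, 64, 32};
static_assert(demo.page_table_entries() == 4 * 32, "one table slot per (sequence, page)");
static_assert(demo.max_tokens_per_seq() == 2048, "64 tokens x 32 pages");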
/*! \brief Compute the backward of the dot product attention with packed KV input.
 *
...@@ -445,6 +481,8 @@ void nvte_fused_attn_bwd_kvpacked(
 * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
 * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
 * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k].
* \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v].
 * \param[in] rng_state Seed and offset of CUDA random number generator.
 * \param[in] max_seqlen_q Max sequence length used for computing for Q.
 *                         it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
...@@ -465,7 +503,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                         const NVTETensor Bias, NVTETensor S, NVTETensor O,
                         NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                         const NVTETensor page_table_v, const NVTETensor rng_state,
                         size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                         float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
                         NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
......
...@@ -36,6 +36,14 @@
      .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
      .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
             NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
.value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) \
.value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \
.value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \
.value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \
.value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \
  pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local()) \
      .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \
      .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \
...@@ -51,7 +59,17 @@
      .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \
      .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \
      .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \
      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \
.value("NVTE_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD) \
.value("NVTE_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD) \
.value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD) \
.value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD) \
.value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD) \
.value("NVTE_Paged_KV_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD) \
.value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \
.value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \
.value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \
.value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \
  pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \
      .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \
      .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
......
...@@ -129,6 +129,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
  auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
  auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);

  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);
...@@ -164,15 +165,16 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
        ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
        dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
        kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
        mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    nvte_fused_attn_fwd(
        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
        kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
        dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
        window_size_right, query_workspace_tensor.data(), nullptr);
...@@ -256,6 +258,7 @@ static void FusedAttnForwardImpl(
      backend, softmax_aux);

  /* Call the underlying NVTE API */
  auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
    auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
...@@ -273,9 +276,10 @@ static void FusedAttnForwardImpl(
    nvte_fused_attn_fwd_kvpacked(
        q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
        &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
        q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(),
        dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
        is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
        window_size_left, window_size_right, workspace_tensor.data(), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
    auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
    auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
...@@ -283,13 +287,13 @@ static void FusedAttnForwardImpl(
    auto q_tensor = TensorWrapper(q, q_shape, dtype);
    auto k_tensor = TensorWrapper(k, k_shape, dtype);
    auto v_tensor = TensorWrapper(v, v_shape, dtype);
    nvte_fused_attn_fwd(
        q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
        o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
        kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
        dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
        q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
        bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
  } else {
    NVTE_ERROR("Unsupported qkv_layout.");
  }
......
...@@ -68,6 +68,7 @@ from transformer_engine.pytorch.distributed import (
)
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.tensor.quantized_tensor import (
    QuantizedTensor,
    prepare_for_saving,
...@@ -76,7 +77,6 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
# Import attention utils
import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
...@@ -85,7 +85,7 @@ from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_p
# Setup Attention Logging
attn_log.setup_logging()

# Global vars for flash attn v2 and v3 imports
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
try:
    fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
    pass  # only print warning if use_flash_attention_2 = True in get_attention_backend
else:
    if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
        if fa_utils.version_required_blackwell <= fa_utils.version <= fa_utils.max_version:
...@@ -143,35 +135,20 @@ else:
            ),
            fa_utils.version,
        )
try:
    fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
except PackageNotFoundError:
    pass  # only print warning if use_flash_attention_3 = True in get_attention_backend
else:
    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flash_attn_3.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
    from flash_attn_3.flash_attn_interface import (
        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
    )
    from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
    from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
fa_utils.set_flash_attention_3_params() fa_utils.set_flash_attention_3_params()
...@@ -179,6 +156,7 @@ else:
_attention_backends = {
    "attention_params": None,
    "use_flash_attention": None,
    "flash_attention_backend": None,
    "use_fused_attention": None,
    "fused_attention_backend": None,
    "use_unfused_attention": None,
...@@ -487,6 +465,89 @@ def _get_cu_seqlens_info_with_cp(
    return _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)]
def get_fa_args(
forward: bool,
use_flash_attn_3: bool,
qkv_format: str,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
dq=None,
dk=None,
dv=None,
):
"""Get forward/backward arguments for flash-attn v2 and v3."""
if use_flash_attn_3:
if forward:
if qkv_format == "thd":
return [
*[None] * 4, # k_new, v_new, qv, out
cu_seqlens_q,
cu_seqlens_kv,
*[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k
max_seqlen_q,
max_seqlen_kv,
*[None]
* 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
]
return [
*[None]
* 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k
max_seqlen_q,
max_seqlen_kv,
*[None]
* 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
]
if qkv_format == "thd":
return [
cu_seqlens_q,
cu_seqlens_kv,
                None,  # seqused_q
                None,  # seqused_k
max_seqlen_q,
max_seqlen_kv,
dq,
dk,
dv,
]
return [
None, # cu_seqlens_q
None, # cu_seqlens_kv
            None,  # seqused_q
            None,  # seqused_k
max_seqlen_q,
max_seqlen_kv,
dq,
dk,
dv,
]
if forward:
if qkv_format == "thd":
return [
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
]
return []
if qkv_format == "thd":
return [
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
]
return [
dq,
dk,
dv,
]
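A minimal usage sketch (illustrative, not part of the change): for the flash-attn v2 path with qkv_format="thd", the returned positional list is spliced between (q, k, v) and the keyword arguments of the varlen call, so one call site can serve both v2 and v3.

    import torch

    cu_seqlens = torch.tensor([0, 3, 7, 12], dtype=torch.int32)  # hypothetical values
    fa_forward_args_thd = get_fa_args(
        True,   # forward pass
        False,  # use_flash_attn_3=False, i.e. the flash-attn v2 path
        "thd",
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=5,
        max_seqlen_kv=5,
    )
    # -> [cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv], which fills the
    # positional slots of _flash_attn_varlen_fwd between (q, k, v) and its kwargs.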
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
    """
    Attention implementation with context parallelism. Exchange KV between CP ranks
@@ -527,6 +588,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
        cp_stream,
        quantizers,
        pad_between_seqs,
+        use_flash_attn_3,
    ):
        # pylint: disable=missing-function-docstring
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
@@ -685,16 +747,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
        if use_fused_attention:
            softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
        else:
-            softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or fa_utils.use_v3
+            softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3
        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
-            if fa_utils.use_v3:
-                if qkv_format == "thd":
-                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
-                else:
-                    flash_attn_fwd = _flash_attn_fwd_v3
+            if use_flash_attn_3:
+                flash_attn_fwd = (
+                    _flash_attn_fwd_v3  # pylint: disable=possibly-used-before-assignment
+                )
                fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
            else:
                if qkv_format == "thd":
@@ -703,7 +764,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                    flash_attn_fwd = _flash_attn_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
-                if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus) or fa_utils.use_v3:
+                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                    fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
                elif fa_utils.v2_7_0_plus:
                    fa_forward_kwargs["window_size_left"] = -1
@@ -856,14 +917,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                            softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                            attn_biases[i] = rest[0] if len(rest) > 0 else None
                        else:
-                            fa_forward_args_thd = []
-                            if qkv_format == "thd":
-                                fa_forward_args_thd = [
-                                    cu_seqlens_q_per_step[i],
-                                    cu_seqlens_kv_per_step[i],
-                                    max_seqlen_q,
-                                    max_seqlen_kv,
-                                ]
+                            fa_forward_args_thd = get_fa_args(
+                                True,
+                                use_flash_attn_3,
+                                qkv_format,
+                                cu_seqlens_q=cu_seqlens_q_per_step[i],
+                                cu_seqlens_kv=cu_seqlens_kv_per_step[i],
+                                max_seqlen_q=max_seqlen_q,
+                                max_seqlen_kv=max_seqlen_kv,
+                            )
                            fa_outputs = flash_attn_fwd(
                                q_inputs[i % 2],
                                (
@@ -883,12 +945,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                            if not fa_utils.v2_7_0_plus:
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
-                                if not fa_utils.use_v3:
+                                if not use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
                            else:
                                out_per_step[i] = fa_outputs[0]
                                softmax_lse_per_step[i] = fa_outputs[1]
-                                if not fa_utils.use_v3:
+                                if not use_flash_attn_3:
                                    rng_states[i] = fa_outputs[3]
                    elif i <= rank:
                        if pad_between_seqs:
@@ -986,15 +1048,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                            softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                            attn_biases[i] = rest[0] if len(rest) > 0 else None
                        else:
-                            fa_forward_args_thd = []
-                            if qkv_format == "thd":
-                                fa_forward_args_thd = [
-                                    cu_seqlens_q_per_step[i],
-                                    cu_seqlens_kv_per_step[i],
-                                    max_seqlen_q,
-                                    max_seqlen_kv // 2,
-                                ]
-                            if fa_utils.use_v3 or (
+                            fa_forward_args_thd = get_fa_args(
+                                True,
+                                use_flash_attn_3,
+                                qkv_format,
+                                cu_seqlens_q=cu_seqlens_q_per_step[i],
+                                cu_seqlens_kv=cu_seqlens_kv_per_step[i],
+                                max_seqlen_q=max_seqlen_q,
+                                max_seqlen_kv=max_seqlen_kv // 2,
+                            )
+                            if use_flash_attn_3 or (
                                fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
                            ):
                                fa_forward_kwargs["window_size"] = (-1, -1)
@@ -1020,12 +1083,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                            if not fa_utils.v2_7_0_plus:
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
-                                if not fa_utils.use_v3:
+                                if not use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
                            else:
                                out_per_step[i] = fa_outputs[0]
                                softmax_lse_per_step[i] = fa_outputs[1]
-                                if not fa_utils.use_v3:
+                                if not use_flash_attn_3:
                                    rng_states[i] = fa_outputs[3]
                    else:
                        if pad_between_seqs:
@@ -1132,15 +1195,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                            softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                            attn_biases[i] = rest[0] if len(rest) > 0 else None
                        else:
-                            fa_forward_args_thd = []
-                            if qkv_format == "thd":
-                                fa_forward_args_thd = [
-                                    cu_seqlens_q_per_step[i],
-                                    cu_seqlens_kv_per_step[i],
-                                    max_seqlen_q // 2,
-                                    max_seqlen_kv,
-                                ]
-                            if fa_utils.use_v3 or (
+                            fa_forward_args_thd = get_fa_args(
+                                True,
+                                use_flash_attn_3,
+                                qkv_format,
+                                cu_seqlens_q=cu_seqlens_q_per_step[i],
+                                cu_seqlens_kv=cu_seqlens_kv_per_step[i],
+                                max_seqlen_q=max_seqlen_q // 2,
+                                max_seqlen_kv=max_seqlen_kv,
+                            )
+                            if use_flash_attn_3 or (
                                fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
                            ):
                                fa_forward_kwargs["window_size"] = (-1, -1)
@@ -1166,12 +1230,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                            if not fa_utils.v2_7_0_plus:
                                out_per_step[i] = fa_outputs[4]
                                softmax_lse_per_step[i] = fa_outputs[5]
-                                if not fa_utils.use_v3:
+                                if not use_flash_attn_3:
                                    rng_states[i] = fa_outputs[7]
                            else:
                                out_per_step[i] = fa_outputs[0]
                                softmax_lse_per_step[i] = fa_outputs[1]
-                                if not fa_utils.use_v3:
+                                if not use_flash_attn_3:
                                    rng_states[i] = fa_outputs[3]
                    else:
                        if pad_between_seqs:
@@ -1254,14 +1318,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                        softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
                        attn_biases[i] = rest[0] if len(rest) > 0 else None
                    else:
-                        fa_forward_args_thd = []
-                        if qkv_format == "thd":
-                            fa_forward_args_thd = [
-                                cu_seqlens_q_per_step[i],
-                                cu_seqlens_kv_per_step[i],
-                                max_seqlen_q,
-                                max_seqlen_kv,
-                            ]
+                        fa_forward_args_thd = get_fa_args(
+                            True,
+                            use_flash_attn_3,
+                            qkv_format,
+                            cu_seqlens_q=cu_seqlens_q_per_step[i],
+                            cu_seqlens_kv=cu_seqlens_kv_per_step[i],
+                            max_seqlen_q=max_seqlen_q,
+                            max_seqlen_kv=max_seqlen_kv,
+                        )
                        fa_outputs = flash_attn_fwd(
                            q,
                            (
@@ -1281,12 +1346,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                        if not fa_utils.v2_7_0_plus:
                            out_per_step[i] = fa_outputs[4]
                            softmax_lse_per_step[i] = fa_outputs[5]
-                            if not fa_utils.use_v3:
+                            if not use_flash_attn_3:
                                rng_states[i] = fa_outputs[7]
                        else:
                            out_per_step[i] = fa_outputs[0]
                            softmax_lse_per_step[i] = fa_outputs[1]
-                            if not fa_utils.use_v3:
+                            if not use_flash_attn_3:
                                rng_states[i] = fa_outputs[3]
                if i > 0:
@@ -1478,6 +1543,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
        ctx.fp8_meta = fp8_meta
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
+        ctx.use_flash_attn_3 = use_flash_attn_3
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
        return out_ret
@@ -1650,11 +1716,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
-            if fa_utils.use_v3:
-                if ctx.qkv_format == "thd":
-                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
-                else:
-                    flash_attn_bwd = _flash_attn_bwd_v3
+            if ctx.use_flash_attn_3:
+                flash_attn_bwd = (
+                    _flash_attn_bwd_v3  # pylint: disable=possibly-used-before-assignment
+                )
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                if ctx.qkv_format == "thd":
@@ -1792,20 +1857,34 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                    else:
                        dq_ = torch.empty_like(q_)
                        dkv_ = torch.empty_like(kv_)
-                        fa_backward_args_thd = []
-                        if ctx.qkv_format == "thd":
-                            fa_backward_args_thd = [
-                                cu_seqlens_q_per_step[cp_size - i - 1],
-                                cu_seqlens_kv_per_step[cp_size - i - 1],
-                                ctx.max_seqlen_q,
-                                ctx.max_seqlen_kv,
-                            ]
-                        if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                        fa_backward_args_thd = get_fa_args(
+                            False,
+                            ctx.use_flash_attn_3,
+                            ctx.qkv_format,
+                            cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
+                            cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
+                            max_seqlen_q=ctx.max_seqlen_q,
+                            max_seqlen_kv=ctx.max_seqlen_kv,
+                            dq=dq_,
+                            dk=(
+                                dkv_[..., 0, :, :]
+                                if ctx.qkv_format in ["bshd", "sbhd"]
+                                else dkv_[0]
+                            ),
+                            dv=(
+                                dkv_[..., 1, :, :]
+                                if ctx.qkv_format in ["bshd", "sbhd"]
+                                else dkv_[1]
+                            ),
+                        )
+                        if ctx.use_flash_attn_3 or (
+                            fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
+                        ):
                            fa_backward_kwargs["window_size"] = (-1, 0)
                        elif fa_utils.v2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = 0
-                        if not fa_utils.use_v3:
+                        if not ctx.use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
                            dout_,
@@ -1814,9 +1893,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
                            out_,
                            softmax_lse,
-                            dq_,
-                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
-                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
                            causal=True,
                            **fa_backward_kwargs,
@@ -1907,20 +1983,34 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                    else:
                        dq_ = torch.empty_like(q_)
                        dkv_ = torch.empty_like(kv_)
-                        fa_backward_args_thd = []
-                        if ctx.qkv_format == "thd":
-                            fa_backward_args_thd = [
-                                cu_seqlens_q_per_step[cp_size - i - 1],
-                                cu_seqlens_kv_per_step[cp_size - i - 1],
-                                ctx.max_seqlen_q,
-                                ctx.max_seqlen_kv // 2,
-                            ]
-                        if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                        fa_backward_args_thd = get_fa_args(
+                            False,
+                            ctx.use_flash_attn_3,
+                            ctx.qkv_format,
+                            cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
+                            cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
+                            max_seqlen_q=ctx.max_seqlen_q,
+                            max_seqlen_kv=ctx.max_seqlen_kv // 2,
+                            dq=dq_,
+                            dk=(
+                                dkv_[..., 0, :, :]
+                                if ctx.qkv_format in ["bshd", "sbhd"]
+                                else dkv_[0]
+                            ),
+                            dv=(
+                                dkv_[..., 1, :, :]
+                                if ctx.qkv_format in ["bshd", "sbhd"]
+                                else dkv_[1]
+                            ),
+                        )
+                        if ctx.use_flash_attn_3 or (
+                            fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
+                        ):
                            fa_backward_kwargs["window_size"] = (-1, -1)
-                        if fa_utils.v2_7_0_plus:
+                        elif fa_utils.v2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
-                        if not fa_utils.use_v3:
+                        if not ctx.use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
                            dout_,
@@ -1929,9 +2019,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
                            out_,
                            softmax_lse,
-                            dq_,
-                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
-                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
                            causal=False,
                            **fa_backward_kwargs,
@@ -2024,20 +2111,34 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                    else:
                        dq_ = torch.empty_like(q_)
                        dkv_ = torch.empty_like(kv_)
-                        fa_backward_args_thd = []
-                        if ctx.qkv_format == "thd":
-                            fa_backward_args_thd = [
-                                cu_seqlens_q_per_step[cp_size - i - 1],
-                                cu_seqlens_kv_per_step[cp_size - i - 1],
-                                ctx.max_seqlen_q // 2,
-                                ctx.max_seqlen_kv,
-                            ]
-                        if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                        fa_backward_args_thd = get_fa_args(
+                            False,
+                            ctx.use_flash_attn_3,
+                            ctx.qkv_format,
+                            cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
+                            cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
+                            max_seqlen_q=ctx.max_seqlen_q // 2,
+                            max_seqlen_kv=ctx.max_seqlen_kv,
+                            dq=dq_,
+                            dk=(
+                                dkv_[..., 0, :, :]
+                                if ctx.qkv_format in ["bshd", "sbhd"]
+                                else dkv_[0]
+                            ),
+                            dv=(
+                                dkv_[..., 1, :, :]
+                                if ctx.qkv_format in ["bshd", "sbhd"]
+                                else dkv_[1]
+                            ),
+                        )
+                        if ctx.use_flash_attn_3 or (
+                            fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
+                        ):
                            fa_backward_kwargs["window_size"] = (-1, -1)
                        elif fa_utils.v2_7_0_plus:
                            fa_backward_kwargs["window_size_left"] = -1
                            fa_backward_kwargs["window_size_right"] = -1
-                        if not fa_utils.use_v3:
+                        if not ctx.use_flash_attn_3:
                            fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                        flash_attn_bwd(
                            dout_,
@@ -2046,9 +2147,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                            kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
                            out_,
                            softmax_lse_,
-                            dq_,
-                            dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
-                            dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                            *fa_backward_args_thd,
                            causal=False,
                            **fa_backward_kwargs,
@@ -2118,20 +2216,24 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                else:
                    dq_ = torch.empty_like(q)
                    dkv_ = torch.empty_like(kv)
-                    fa_backward_args_thd = []
-                    if ctx.qkv_format == "thd":
-                        fa_backward_args_thd = [
-                            cu_seqlens_q_per_step[cp_size - i - 1],
-                            cu_seqlens_kv_per_step[cp_size - i - 1],
-                            ctx.max_seqlen_q,
-                            ctx.max_seqlen_kv,
-                        ]
-                    if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                    fa_backward_args_thd = get_fa_args(
+                        False,
+                        ctx.use_flash_attn_3,
+                        ctx.qkv_format,
+                        cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
+                        cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
+                        max_seqlen_q=ctx.max_seqlen_q,
+                        max_seqlen_kv=ctx.max_seqlen_kv,
+                        dq=dq_,
+                        dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
+                        dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
+                    )
+                    if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
                        fa_backward_kwargs["window_size"] = (-1, -1)
                    elif fa_utils.v2_7_0_plus:
                        fa_backward_kwargs["window_size_left"] = -1
                        fa_backward_kwargs["window_size_right"] = -1
-                    if not fa_utils.use_v3:
+                    if not ctx.use_flash_attn_3:
                        fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
                    flash_attn_bwd(
                        dout,
@@ -2140,9 +2242,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
                        kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
                        out,
                        softmax_lse,
-                        dq_,
-                        dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
-                        dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
                        *fa_backward_args_thd,
                        causal=False,
                        **fa_backward_kwargs,
...@@ -2382,6 +2481,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2382,6 +2481,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
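For context: torch.autograd.Function requires backward to return one gradient per forward input, so the extra None above mirrors the new use_flash_attn_3 argument. A minimal self-contained illustration (not from this diff):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor):  # two forward inputs
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_out):  # must return two gradients
            return grad_out * ctx.factor, None  # None for the non-tensor input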
@@ -2440,6 +2540,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
        window_size,
        cp_group,
        cp_stream,
+        use_flash_attn_3,
    ):
        # pylint: disable=missing-function-docstring
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
@@ -2465,11 +2566,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
-            if fa_utils.use_v3:
-                if qkv_format == "thd":
-                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
-                else:
-                    flash_attn_fwd = _flash_attn_fwd_v3
+            if use_flash_attn_3:
+                flash_attn_fwd = _flash_attn_fwd_v3
            else:
                if qkv_format == "thd":
                    flash_attn_fwd = _flash_attn_varlen_fwd
@@ -2584,15 +2682,16 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
                            window_size=window_size_per_step[i],
                        )
                    else:
-                        fa_forward_args_thd = []
-                        if qkv_format == "thd":
-                            fa_forward_args_thd = [
-                                cu_seqlens_q,
-                                cu_seqlens_kv_per_step[i],
-                                max_seqlen_q,
-                                max_seqlen_kv_,
-                            ]
-                        if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                        fa_forward_args_thd = get_fa_args(
+                            True,
+                            use_flash_attn_3,
+                            qkv_format,
+                            cu_seqlens_q=cu_seqlens_q,
+                            cu_seqlens_kv=cu_seqlens_kv_per_step[i],
+                            max_seqlen_q=max_seqlen_q,
+                            max_seqlen_kv=max_seqlen_kv_,
+                        )
+                        if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
                            fa_forward_kwargs["window_size"] = window_size_per_step[i]
                        elif fa_utils.v2_7_0_plus:
                            fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
@@ -2608,12 +2707,12 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
                        if not fa_utils.v2_7_0_plus:
                            out_per_step[i] = fa_outputs[4]
                            softmax_lse_per_step[i] = fa_outputs[5]
-                            if not fa_utils.use_v3:
+                            if not use_flash_attn_3:
                                rng_states[i] = fa_outputs[7]
                        else:
                            out_per_step[i] = fa_outputs[0]
                            softmax_lse_per_step[i] = fa_outputs[1]
-                            if not fa_utils.use_v3:
+                            if not use_flash_attn_3:
                                rng_states[i] = fa_outputs[3]
            if i > 0:
@@ -2658,6 +2757,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
        ctx.attn_mask_type = attn_mask_type
        ctx.deterministic = deterministic
        ctx.use_fused_attention = use_fused_attention
+        ctx.use_flash_attn_3 = use_flash_attn_3
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
        return out
@@ -2713,11 +2813,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
-            if fa_utils.use_v3:
-                if ctx.qkv_format == "thd":
-                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
-                else:
-                    flash_attn_bwd = _flash_attn_bwd_v3
+            if ctx.use_flash_attn_3:
+                flash_attn_bwd = _flash_attn_bwd_v3
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
                if ctx.qkv_format == "thd":
@@ -2778,19 +2875,25 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
                    dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
                        torch.empty_like(x) for x in [q_, k_, v_]
                    ]
-                    fa_backward_args_thd = []
-                    if ctx.qkv_format == "thd":
-                        fa_backward_args_thd = [
-                            cu_seqlens_q,
-                            cu_seqlens_kv_per_step[i],
-                            ctx.max_seqlen_q,
-                            max_seqlen_kv,
-                        ]
-                    if not fa_utils.use_v3:
+                    fa_backward_args_thd = get_fa_args(
+                        False,
+                        ctx.use_flash_attn_3,
+                        ctx.qkv_format,
+                        cu_seqlens_q=cu_seqlens_q,
+                        cu_seqlens_kv=cu_seqlens_kv_per_step[i],
+                        max_seqlen_q=ctx.max_seqlen_q,
+                        max_seqlen_kv=max_seqlen_kv,
+                        dq=dq_per_step[i],
+                        dk=dk_per_step[i],
+                        dv=dv_per_step[i],
+                    )
+                    if not ctx.use_flash_attn_3:
                        fa_backward_kwargs["rng_state"] = rng_states[i]
-                    if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
+                    if ctx.use_flash_attn_3 or (
+                        fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
+                    ):
                        fa_backward_kwargs["window_size"] = window_size_per_step[i]
-                    if fa_utils.v2_7_0_plus:
+                    elif fa_utils.v2_7_0_plus:
                        fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
                        fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
                    flash_attn_bwd(
@@ -2800,9 +2903,6 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
                        v_,
                        out_,
                        softmax_lse_per_step[i],
-                        dq_per_step[i],
-                        dk_per_step[i],
-                        dv_per_step[i],
                        *fa_backward_args_thd,
                        causal="causal" in ctx.attn_mask_type,
                        **fa_backward_kwargs,
...@@ -2870,6 +2970,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function): ...@@ -2870,6 +2970,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
@@ -2906,6 +3007,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
        cp_group,
        cp_stream,
        quantizers,
+        use_flash_attn_3,
    ):
        # pylint: disable=missing-function-docstring
        nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
@@ -2930,11 +3032,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
        flash_attn_fwd = None
        if not use_fused_attention:
            fa_forward_kwargs = {"softmax_scale": softmax_scale}
-            if fa_utils.use_v3:
-                if qkv_format == "thd":
-                    flash_attn_fwd = _flash_attn_varlen_fwd_v3
-                else:
-                    flash_attn_fwd = _flash_attn_fwd_v3
+            if use_flash_attn_3:
+                flash_attn_fwd = _flash_attn_fwd_v3
                fa_forward_kwargs["window_size"] = window_size
            else:
                if qkv_format == "thd":
@@ -2943,7 +3042,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                    flash_attn_fwd = _flash_attn_fwd
                fa_forward_kwargs["dropout_p"] = dropout_p
                fa_forward_kwargs["return_softmax"] = False
-                if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                    fa_forward_kwargs["window_size"] = window_size
                elif fa_utils.v2_7_0_plus:
                    fa_forward_kwargs["window_size_left"] = window_size[0]
@@ -3048,14 +3147,15 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
            if fp8:
                out = out._data
        else:
-            fa_forward_args_thd = []
-            if qkv_format == "thd":
-                fa_forward_args_thd = [
-                    cu_seqlens_q,
-                    cu_seqlens_kv,
-                    max_seqlen_q,
-                    max_seqlen_kv,
-                ]
+            fa_forward_args_thd = get_fa_args(
+                True,
+                use_flash_attn_3,
+                qkv_format,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_kv=cu_seqlens_kv,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_kv=max_seqlen_kv,
+            )
            fa_outputs = flash_attn_fwd(
                q,
                k,
@@ -3066,10 +3166,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
            )
            if not fa_utils.v2_7_0_plus:
                out, softmax_lse = fa_outputs[4], fa_outputs[5]
-                rng_state = fa_outputs[7] if not fa_utils.use_v3 else None
+                rng_state = fa_outputs[7] if not use_flash_attn_3 else None
            else:
                out, softmax_lse = fa_outputs[0], fa_outputs[1]
-                rng_state = fa_outputs[3] if not fa_utils.use_v3 else None
+                rng_state = fa_outputs[3] if not use_flash_attn_3 else None
            aux_ctx_tensors = [softmax_lse, rng_state]
        chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
@@ -3152,6 +3252,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
        ctx.fp8_meta = fp8_meta
        ctx.is_input_fp8 = is_input_fp8
        ctx.is_output_fp8 = is_output_fp8
+        ctx.use_flash_attn_3 = use_flash_attn_3
        nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
        return out_ret
@@ -3240,11 +3341,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
        flash_attn_bwd = None
        if not ctx.use_fused_attention:
            fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
-            if fa_utils.use_v3:
-                if ctx.qkv_format == "thd":
-                    flash_attn_bwd = _flash_attn_varlen_bwd_v3
-                else:
-                    flash_attn_bwd = _flash_attn_bwd_v3
+            if ctx.use_flash_attn_3:
+                flash_attn_bwd = (
+                    _flash_attn_bwd_v3  # pylint: disable=possibly-used-before-assignment
+                )
                fa_backward_kwargs["window_size"] = ctx.window_size
                fa_backward_kwargs["deterministic"] = ctx.deterministic
            else:
@@ -3253,7 +3353,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                else:
                    flash_attn_bwd = _flash_attn_bwd
                fa_backward_kwargs["dropout_p"] = ctx.dropout_p
-                if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
+                if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
                    fa_backward_kwargs["window_size"] = ctx.window_size
                elif fa_utils.v2_7_0_plus:
                    fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
@@ -3321,15 +3421,19 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
        else:
            softmax_lse, rng_state = aux_ctx_tensors
            dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
-            fa_backward_args_thd = []
-            if ctx.qkv_format == "thd":
-                fa_backward_args_thd = [
-                    cu_seqlens_q,
-                    cu_seqlens_kv,
-                    ctx.max_seqlen_q,
-                    ctx.max_seqlen_kv,
-                ]
-            if not fa_utils.use_v3:
+            fa_backward_args_thd = get_fa_args(
+                False,
+                ctx.use_flash_attn_3,
+                ctx.qkv_format,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_kv=cu_seqlens_kv,
+                max_seqlen_q=ctx.max_seqlen_q,
+                max_seqlen_kv=ctx.max_seqlen_kv,
+                dq=dq,
+                dk=dk,
+                dv=dv,
+            )
+            if not ctx.use_flash_attn_3:
                fa_backward_kwargs["rng_state"] = rng_state
            flash_attn_bwd(
                dout,
@@ -3338,9 +3442,6 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                v,
                out,
                softmax_lse,
-                dq,
-                dk,
-                dv,
                *fa_backward_args_thd,
                causal=causal,
                **fa_backward_kwargs,
@@ -3427,6 +3528,7 @@ def attn_forward_func_with_cp(
    fp8_meta=None,
    quantizers=None,
    pad_between_seqs=False,
+    use_flash_attn_3=False,
) -> torch.Tensor:
    """
    Attention implementation with context parallelism.
@@ -3494,15 +3596,24 @@ def attn_forward_func_with_cp(
    ]
    if cp_comm_type in ["p2p", "a2a+p2p"]:
-        args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, pad_between_seqs]
+        args += [
+            fp8,
+            fp8_meta,
+            cp_group,
+            cp_global_ranks,
+            cp_stream,
+            quantizers,
+            pad_between_seqs,
+            use_flash_attn_3,
+        ]
        out = AttnFuncWithCPAndKVP2P.apply(*args)
    elif cp_comm_type == "all_gather":
        args.pop(5)
        args.pop(8)
-        args += [window_size, cp_group, cp_stream]
+        args += [window_size, cp_group, cp_stream, use_flash_attn_3]
        out = AttnFuncWithCPAndKVAllGather.apply(*args)
    elif cp_comm_type == "a2a":
-        args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers]
+        args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3]
        out = AttnFuncWithCPAndQKVOA2A.apply(*args)
    else:
        raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
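A compact summary of the dispatch above (a sketch; the mapping restates this function, and the comments paraphrase the docstrings of the three classes in this file):

    # cp_comm_type -> autograd function implementing that communication pattern
    _CP_DISPATCH = {
        "p2p": AttnFuncWithCPAndKVP2P,            # exchange KV between CP ranks via P2P
        "a2a+p2p": AttnFuncWithCPAndKVP2P,        # hierarchical: A2A within, P2P across
        "all_gather": AttnFuncWithCPAndKVAllGather,
        "a2a": AttnFuncWithCPAndQKVOA2A,          # all-to-all on QKV and output
    }

Each branch now appends use_flash_attn_3 to its argument list, so the FA3 decision made by the caller propagates into every context-parallel implementation.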
@@ -3694,23 +3805,48 @@ class UnfusedDotProductAttention(torch.nn.Module):
        core_attention_bias_type: str = "no_bias",
        core_attention_bias: Optional[torch.Tensor] = None,
        alibi_slopes: Optional[torch.Tensor] = None,
+        inference_params: Optional[InferenceParams] = None,
    ) -> torch.Tensor:
        """Unfused attention fprop"""
        assert (
            qkv_layout in QKVLayouts
        ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
-        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+        # get q_format and kv_format for training and inference
+        qkv_format, q_format, _ = dpa_utils.get_qkv_format(qkv_layout, inference_params)
+        if inference_params is not None and inference_params.is_paged:
+            key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number)
        if qkv_format == "bshd":
            # convert to sbhd and use sbhd implementation for now
            query_layer, key_layer, value_layer = [
                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
            ]
+        if qkv_format == "sbhd_2bshd":
+            key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]]
+        total_tokens, batch_size = None, None
+        if qkv_format == "thd_2bshd":
+            total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0]
+            query_layer = tex.convert_thd_to_bshd(
+                query_layer,
+                cu_seqlens_q,
+                batch_size,
+                inference_params.max_ctx_len,
+            )
+            query_layer, key_layer, value_layer = [
+                x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
+            ]
        batch_size, max_seqlen_q, max_seqlen_kv = (
            query_layer.shape[1],
            query_layer.shape[0],
            key_layer.shape[0],
        )
+        if "padding" in attn_mask_type and attention_mask is None:
+            attention_mask = dpa_utils.get_padding_mask(
+                batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
+            )
        attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
            dpa_utils.get_full_mask(
                max_seqlen_q,
@@ -3843,20 +3979,34 @@ class UnfusedDotProductAttention(torch.nn.Module):
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
-        if qkv_format == "sbhd":
+        if q_format == "sbhd":
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            # [sq, b, np, hn] --> [sq, b, hp]
            context_layer = context_layer.view(seqlen, batch_size, -1)
-        if qkv_format == "bshd":
+        if q_format == "bshd":
            # [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            # [b, sq, np, hn] --> [b, sq, hp]
            context_layer = context_layer.view(batch_size, seqlen, -1)
+        if q_format == "thd":
+            # [b, np, sq, hn] --> [b, sq, np, hn]
+            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+            # [b, sq, np, hn] --> [tq, np, hn]
+            context_layer = tex.convert_bshd_to_thd(
+                context_layer,
+                cu_seqlens_q,
+                total_tokens,
+            )
+            # [tq, np, hn] --> [tq, hp]
+            context_layer = context_layer.view(total_tokens, -1)
        return context_layer
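For intuition, a pure-PyTorch sketch of what the tex.convert_thd_to_bshd kernel does conceptually (an illustrative re-implementation, not the actual kernel; padded positions are zero-filled):

    import torch

    def thd_to_bshd_reference(x_thd, cu_seqlens, batch_size, max_seqlen):
        """Scatter a packed [total_tokens, h, d] tensor into a padded [b, s, h, d] one."""
        h, d = x_thd.shape[1], x_thd.shape[2]
        out = x_thd.new_zeros(batch_size, max_seqlen, h, d)
        for b in range(batch_size):
            start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
            out[b, : end - start] = x_thd[start:end]
        return out

tex.convert_bshd_to_thd is the inverse gather, which is why the pair can round-trip tensors between packed thd inference inputs and the padded bshd math used above.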
@@ -3951,6 +4101,8 @@ class FlashAttention(torch.nn.Module):
        fp8: bool = False,
        fp8_meta: Optional[Dict[str, Any]] = None,
        quantizers=None,
+        inference_params: Optional[InferenceParams] = None,
+        flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
    ) -> torch.Tensor:
        """flash-attn fprop"""
@@ -3973,8 +4125,10 @@ class FlashAttention(torch.nn.Module):
                cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1
-        qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+        # get q_format and kv_format for training and inference
+        qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params)
+        # convert q, k, v to bshd if they are in sbhd; qkv_format doesn't change
        if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
            if qkv_format == "sbhd":
                # For now just 128, will make it more general in the future
@@ -3988,8 +4142,11 @@ class FlashAttention(torch.nn.Module):
                    )
                else:
                    query_layer, key_layer, value_layer = [
-                        x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
+                        x.transpose(0, 1).contiguous()
+                        for x in (query_layer, key_layer, value_layer)
                    ]
+            elif q_format == "sbhd" and kv_format == "bshd":
+                query_layer = query_layer.transpose(0, 1).contiguous()
            if context_parallel:
                query_layer, key_layer, value_layer = [
                    x.contiguous() for x in (query_layer, key_layer, value_layer)
@@ -3997,93 +4154,129 @@ class FlashAttention(torch.nn.Module):
        else:
            if qkv_format == "sbhd":
                query_layer._data, key_layer._data, value_layer._data = [
-                    x.transpose(0, 1)
+                    x.transpose(0, 1).contiguous()
                    for x in (query_layer._data, key_layer._data, value_layer._data)
                ]
                query_layer, key_layer, value_layer = [
                    Float8Tensor.make_like(x, data=x._data, shape=x._data.shape)
                    for x in (query_layer, key_layer, value_layer)
                ]
+            elif q_format == "sbhd" and kv_format == "bshd":
+                query_layer._data = query_layer._data.transpose(0, 1).contiguous()
+                query_layer = Float8Tensor.make_like(
+                    query_layer, data=query_layer._data, shape=query_layer._data.shape
+                )
            if context_parallel:
                query_layer._data, key_layer._data, value_layer._data = [
                    x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
                ]
-        batch_size = query_layer.shape[0]
-        if qkv_format in ["sbhd", "bshd"]:
-            max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
-            max_seqlen_q *= cp_size
-            max_seqlen_kv *= cp_size
-            if "padding" in attn_mask_type:
-                assert not context_parallel, "Padding mask not supported with context parallelism!"
-                # [b * s, h, d]
-                query_layer, key_layer, value_layer = [
-                    x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
-                    for x in [query_layer, key_layer, value_layer]
-                ]
-                if self.attention_type == "self":
-                    assert (
-                        max_seqlen_q == max_seqlen_kv
-                    ), "Maximum sequence length for Q and KV should be the same."
-                    if cu_seqlens_q is None:
-                        assert (
-                            attention_mask is not None
-                        ), "Please provide attention_mask for padding!"
-                        cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
-                            attention_mask
-                        )
-                    else:
-                        indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
-                    cu_seqlens_kv = cu_seqlens_q
-                    query_layer, key_layer, value_layer = dpa_utils.PackTensors.apply(
-                        indices_q, query_layer, key_layer, value_layer
-                    )
-                else:
-                    if cu_seqlens_q is None or cu_seqlens_kv is None:
-                        assert (
-                            attention_mask is not None
-                        ), "Please provide attention_mask for padding!"
-                        cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
-                            attention_mask[0]
-                        )
-                        cu_seqlens_kv, indices_kv = dpa_utils.get_cu_seqlens_and_indices(
-                            attention_mask[1]
-                        )
-                    else:
-                        indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
-                        indices_kv = dpa_utils.get_indices(max_seqlen_kv, cu_seqlens_kv)
-                    query_layer = dpa_utils.PackTensors.apply(indices_q, query_layer)
-                    key_layer, value_layer = dpa_utils.PackTensors.apply(
-                        indices_kv, key_layer, value_layer
-                    )
-            else:
-                # Cumulative sequence lengths for unpadded data
-                if cu_seqlens_q is None:
-                    cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
-                        batch_size,
-                        max_seqlen_q,
-                        query_layer.device,
-                    )
-                if cu_seqlens_kv is None:
-                    cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
-                        batch_size,
-                        max_seqlen_kv,
-                        key_layer.device,
-                    )
-        elif qkv_format == "thd":
-            assert (
-                cu_seqlens_q is not None and cu_seqlens_kv is not None
-            ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
-            if max_seqlen_q is None:
-                seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-                max_seqlen_q = seqlens_q.max().item()
-            if max_seqlen_kv is None:
-                seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-                max_seqlen_kv = seqlens_kv.max().item()
+        # get batch_size, max_seqlen and cu_seqlens
+        batch_size, context_len = None, None
+        if inference_params is None:
+            if qkv_format in ["sbhd", "bshd"]:
+                batch_size = query_layer.shape[0]
+                max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
+                max_seqlen_q *= cp_size
+                max_seqlen_kv *= cp_size
+                if "padding" in attn_mask_type:
+                    assert (
+                        not context_parallel
+                    ), "Padding mask not supported with context parallelism!"
+                    # [b * s, h, d]
+                    query_layer, key_layer, value_layer = [
+                        x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
+                        for x in [query_layer, key_layer, value_layer]
+                    ]
+                    if self.attention_type == "self":
+                        assert (
+                            max_seqlen_q == max_seqlen_kv
+                        ), "Maximum sequence length for Q and KV should be the same."
+                        if cu_seqlens_q is None:
+                            assert (
+                                attention_mask is not None
+                            ), "Please provide attention_mask for padding!"
+                            cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
+                                attention_mask
+                            )
+                        else:
+                            indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
+                        cu_seqlens_kv = cu_seqlens_q
+                        query_layer, key_layer, value_layer = dpa_utils.PackTensors.apply(
+                            indices_q, query_layer, key_layer, value_layer
+                        )
+                    else:
+                        if cu_seqlens_q is None or cu_seqlens_kv is None:
+                            assert (
+                                attention_mask is not None
+                            ), "Please provide attention_mask for padding!"
+                            cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices(
+                                attention_mask[0]
+                            )
+                            cu_seqlens_kv, indices_kv = dpa_utils.get_cu_seqlens_and_indices(
+                                attention_mask[1]
+                            )
+                        else:
+                            indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q)
+                            indices_kv = dpa_utils.get_indices(max_seqlen_kv, cu_seqlens_kv)
+                        query_layer = dpa_utils.PackTensors.apply(indices_q, query_layer)
+                        key_layer, value_layer = dpa_utils.PackTensors.apply(
+                            indices_kv, key_layer, value_layer
+                        )
+                else:
+                    # Cumulative sequence lengths for unpadded data
+                    if cu_seqlens_q is None:
+                        cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
+                            batch_size,
+                            max_seqlen_q,
+                            query_layer.device,
+                        )
+                    if cu_seqlens_kv is None:
+                        cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
+                            batch_size,
+                            max_seqlen_kv,
+                            key_layer.device,
+                        )
+            elif qkv_format == "thd":
+                assert (
+                    cu_seqlens_q is not None and cu_seqlens_kv is not None
+                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
+                if max_seqlen_q is None:
+                    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
+                    max_seqlen_q = seqlens_q.max().item()
+                if max_seqlen_kv is None:
+                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+                    max_seqlen_kv = seqlens_kv.max().item()
+        else:
+            if qkv_format in ["sbhd_2bshd", "bshd"]:
+                # q is in bshd in both cases from conversion above or the original input
+                batch_size, context_len = query_layer.shape[:2]
+                cu_seqlens_q = cu_seqlens_q[: batch_size + 1]
+                cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1]
+                # convert from bshd to thd_2bshd for flash_attn_varlen_func/_with_kvcache;
+                # kernel assumes tensor is contiguous
+                if isinstance(query_layer, Float8Tensor):
+                    query_layer._data = tex.convert_bshd_to_thd(
+                        query_layer._data,
+                        cu_seqlens_q,
+                        batch_size * context_len,
+                    )
+                    query_layer = Float8Tensor.make_like(
+                        query_layer, data=query_layer._data, shape=query_layer._data.shape
+                    )
+                else:
+                    query_layer = tex.convert_bshd_to_thd(
+                        query_layer,
+                        cu_seqlens_q,
+                        batch_size * context_len,
+                    )
+
+        use_flash_attn_3 = False
+        if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
+            use_flash_attn_3 = True
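        # Illustrative arithmetic (hypothetical values): cu_seqlens_kv = [0, 3, 7, 12]
        # gives seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] = [3, 4, 5], so
        # max_seqlen_kv = 5; the same subtraction later yields cache_seqlens for
        # flash_attn_with_kvcache.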
        if context_parallel and all(
            not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
        ):
@@ -4114,6 +4307,7 @@ class FlashAttention(torch.nn.Module):
                window_size=window_size,
                quantizers=quantizers,
                pad_between_seqs=False,
+                use_flash_attn_3=use_flash_attn_3,
            )
        else:
@@ -4126,30 +4320,77 @@ class FlashAttention(torch.nn.Module):
                        tensor.activation_offloading = True
            with self.attention_dropout_ctx():
-                fa_optional_forward_kwargs = {}
-                if fa_utils.v2_3_plus:
-                    fa_optional_forward_kwargs["window_size"] = window_size
-                if fa_utils.v2_4_plus:
-                    fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
-                if fa_utils.v2_4_1_plus:
-                    fa_optional_forward_kwargs["deterministic"] = self.deterministic
+                #       | API                     | use cases
+                # ----------------------------------------------------------------------
+                # FA v2 | flash_attn_func         | bshd/sbhd + not padding
+                #       | flash_attn_varlen_func  | bshd/sbhd + padding
+                #       |                         | thd + padding
+                #       |                         | KV cache (not-paged/paged), i.e.
+                #       |                         | bshd/sbhd/thd + padding
+                # FA v3 | flash_attn_func         | bshd/sbhd + not padding
+                #       | flash_attn_varlen_func  | bshd/sbhd + padding
+                #       |                         | thd + padding
+                #       | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e.
+                #       |                         | bshd/sbhd/thd + padding
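                # Illustrative call shapes (abbreviated; the signatures are assumptions,
                # not taken from this diff):
                #   flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv,
                #                          max_seqlen_q, max_seqlen_kv, ...)
                #   flash_attn_with_kvcache(q, k_cache, v_cache, ...,
                #                           cache_seqlens=..., page_table=...)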
                fa_optional_forward_args_thd = []
                if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
-                    func = flash_attn_func if not fa_utils.use_v3 else flash_attn_func_v3
+                    func = (
+                        flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
+                    )  # pylint: disable=possibly-used-before-assignment
                else:
-                    if fa_utils.v2_5_7_plus:
-                        fa_optional_forward_kwargs["block_table"] = None
-                    func = (
-                        flash_attn_varlen_func if not fa_utils.use_v3 else flash_attn_varlen_func_v3
-                    )
-                    fa_optional_forward_args_thd.append(cu_seqlens_q)
-                    fa_optional_forward_args_thd.append(cu_seqlens_kv)
-                    fa_optional_forward_args_thd.append(max_seqlen_q)
-                    fa_optional_forward_args_thd.append(max_seqlen_kv)
-                if fa_utils.use_v3:
+                    if not use_flash_attn_3:
+                        func = flash_attn_varlen_func
+                    elif inference_params is None:
+                        func = flash_attn_varlen_func_v3  # pylint: disable=possibly-used-before-assignment
+                    else:
+                        func = flash_attn_with_kvcache_v3  # pylint: disable=possibly-used-before-assignment
+                    if not use_flash_attn_3 or inference_params is None:
+                        fa_optional_forward_args_thd.append(cu_seqlens_q)
+                        fa_optional_forward_args_thd.append(cu_seqlens_kv)
+                        fa_optional_forward_args_thd.append(max_seqlen_q)
+                        fa_optional_forward_args_thd.append(max_seqlen_kv)
+                if not use_flash_attn_3:
+                    fa_optional_forward_kwargs = {}
+                    if fa_utils.v2_3_plus:
+                        fa_optional_forward_kwargs["window_size"] = window_size
+                    if fa_utils.v2_4_plus:
+                        fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
+                    if fa_utils.v2_4_1_plus:
+                        fa_optional_forward_kwargs["deterministic"] = self.deterministic
+                    if inference_params is not None:
+                        # use block_table kwarg to support thd_2bshd for non-paged
+                        fa_optional_forward_kwargs["block_table"] = (
+                            inference_params.cache_manager.page_table[:batch_size]
+                            if inference_params.is_paged
+                            else inference_params.cache_manager.batch_indices_post_step.unsqueeze(
+                                1
+                            )[:batch_size]
+                        )
+                    output = func(
+                        query_layer,
+                        key_layer,
+                        value_layer,
+                        *fa_optional_forward_args_thd,
+                        self.attention_dropout if self.training else 0.0,
+                        softmax_scale=self.softmax_scale,
+                        causal="causal" in attn_mask_type,
+                        **fa_optional_forward_kwargs,
+                    )
+                else:
                    fa_3_optional_forward_kwargs = {}
                    fa_3_optional_forward_kwargs["window_size"] = window_size
-                    fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
+                    if inference_params is None:
+                        fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
+                    else:
+                        fa_3_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
+                        fa_3_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
+                        cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
+                        fa_3_optional_forward_kwargs["cache_seqlens"] = cache_seqlens
+                        # flash_attn_with_kvcache accepts thd_2bshd for non-paged
+                        if inference_params.is_paged:
+                            fa_3_optional_forward_kwargs["page_table"] = (
+                                inference_params.cache_manager.page_table[:batch_size]
+                            )
                    if fp8:
                        QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
                        torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
@@ -4174,21 +4415,23 @@ class FlashAttention(torch.nn.Module):
                            query_layer, key_layer, value_layer = (
                                QKV_quantizer(x) for x in [query_layer, key_layer, value_layer]
                            )
-                            fa_3_optional_forward_kwargs["descale_q"] = (
-                                query_layer._scale_inv.unsqueeze(0)
-                            )
-                            fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze(
-                                0
-                            )
-                            fa_3_optional_forward_kwargs["descale_v"] = (
-                                value_layer._scale_inv.unsqueeze(0)
-                            )
+                            batch_size = cu_seqlens_q.shape[0] - 1
+                            num_heads_k = key_layer.shape[-2]
+                            fa_3_optional_forward_kwargs["q_descale"] = (
+                                query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k)
+                            )
+                            fa_3_optional_forward_kwargs["k_descale"] = key_layer._scale_inv.unsqueeze(
+                                0
+                            ).repeat(batch_size, num_heads_k)
+                            fa_3_optional_forward_kwargs["v_descale"] = (
+                                value_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k)
+                            )
                            query_layer, key_layer, value_layer = (
                                convert_to_torch_float8(x, torch_dtype)
                                for x in [query_layer, key_layer, value_layer]
                            )
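                            # Note (inferred from the repeat pattern above): FA3 expects
                            # q_descale/k_descale/v_descale with one value per (batch,
                            # KV head), hence each scalar scale-inverse is broadcast to
                            # shape [batch_size, num_heads_k].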
                    try:
-                        output, _ = func(
+                        output = func(
                            query_layer,
                            key_layer,
                            value_layer,
@@ -4197,6 +4440,8 @@ class FlashAttention(torch.nn.Module):
                            causal="causal" in attn_mask_type,
                            **fa_3_optional_forward_kwargs,
                        )
+                        if isinstance(output, (List, Tuple)):
+                            output = output[0]
                    except TypeError as e:
                        if fa_utils.v3_0_0_beta:
                            e.args = (
@@ -4212,22 +4457,30 @@ class FlashAttention(torch.nn.Module):
                    if fp8 and fp8_meta["recipe"].fp8_mha:
                        O_quantizer = quantizers["scaling_fwd"][META_O]
                        output = O_quantizer(output)
-                else:
-                    output = func(
-                        query_layer,
-                        key_layer,
-                        value_layer,
-                        *fa_optional_forward_args_thd,
-                        self.attention_dropout if self.training else 0.0,
-                        softmax_scale=self.softmax_scale,
-                        causal="causal" in attn_mask_type,
-                        **fa_optional_forward_kwargs,
-                    )
if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: if inference_params is None:
output = dpa_utils.UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
output = dpa_utils.UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
elif qkv_format in ["bshd", "sbhd_2bshd"]:
# all KV caching cases use thd_2bshd for calculation
# convert results back to bshd from thd_2bshd
if isinstance(query_layer, Float8Tensor):
output._data = tex.convert_thd_to_bshd(
output._data,
cu_seqlens_q,
batch_size,
context_len,
)
output = Float8Tensor.make_like(output, data=output._data, shape=output._data.shape)
else:
output = tex.convert_thd_to_bshd(
output,
cu_seqlens_q,
batch_size,
context_len,
)
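For reference, a pure-PyTorch sketch of what tex.convert_thd_to_bshd computes (matching the CUDA kernel added later in this commit: valid tokens are scattered into a zero-initialized bshd tensor; context_len above plays the role of max_seq_len):

import torch

def convert_thd_to_bshd_ref(tensor, cu_seqlens, b, max_seq_len):
    # tensor: [total_tokens, h, d]; cu_seqlens: [b + 1]
    h, d = tensor.shape[1], tensor.shape[2]
    out = torch.zeros(b, max_seq_len, h, d, dtype=tensor.dtype, device=tensor.device)
    for i in range(b):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        out[i, : end - start] = tensor[start:end]
    return out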
if qkv_format == "sbhd": if q_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd) # (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha: if fp8 and fp8_meta["recipe"].fp8_mha:
output_data = ( output_data = (
...@@ -4242,10 +4495,10 @@ class FlashAttention(torch.nn.Module): ...@@ -4242,10 +4495,10 @@ class FlashAttention(torch.nn.Module):
) )
else: else:
output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
elif qkv_format == "bshd": elif q_format == "bshd":
# (bs)hd -> bs(hd) # (bs)hd -> bs(hd)
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
elif qkv_format == "thd": elif q_format == "thd":
# thd -> t(hd) # thd -> t(hd)
output = output.reshape(output.shape[0], -1) output = output.reshape(output.shape[0], -1)
@@ -4296,6 +4549,8 @@ class FusedAttnFunc(torch.autograd.Function):
        cu_seqlens_kv,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        page_table_k,
        page_table_v,
        q,
        k,
        v,
@@ -4340,7 +4595,7 @@ class FusedAttnFunc(torch.autograd.Function):
            q_fp8, k_fp8, v_fp8 = q, k, v
        else:
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
            match qkv_group:
                case 1:
                    dim = qkv_layout.find("3")
@@ -4376,6 +4631,8 @@ class FusedAttnFunc(torch.autograd.Function):
            attn_bias,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            None,
            None,
            S_quantizer,
            O_quantizer,
            attn_scale,
@@ -4398,7 +4655,7 @@ class FusedAttnFunc(torch.autograd.Function):
        if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
            # 1: qkv packed, 2: kv packed, 3: qkv separate
            if is_input_fp8:
                qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
                if qkv_group == 1:
                    dim = qkv_layout.find("3")
                    qkv = _combine_tensors([q, k, v], dim)
@@ -4407,7 +4664,7 @@ class FusedAttnFunc(torch.autograd.Function):
                    q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
                if qkv_group == 2:
                    q = q.dequantize()
                    dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2")
                    kv = _combine_tensors([k, v], dim)
                    kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
                    kv_no_fp8 = kv.dequantize()
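The replace("paged_kv_", "") above keeps the group count correct for the new paged layouts; a quick illustration:

for layout in ["sb3hd", "sbhd_sb2hd", "thd_thd_thd", "paged_kv_thd_bshd_bshd"]:
    qkv_group = len(layout.replace("paged_kv_", "").split("_"))
    print(layout, "->", qkv_group)  # 1 (qkv packed), 2 (kv packed), 3, 3 (qkv separate)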
@@ -4436,6 +4693,8 @@ class FusedAttnFunc(torch.autograd.Function):
                attn_bias,
                cu_seqlens_q_padded,
                cu_seqlens_kv_padded,
                page_table_k,
                page_table_v,
                None,  # s_quantizer
                None,  # o_quantizer
                attn_scale,
@@ -4612,7 +4871,7 @@ class FusedAttnFunc(torch.autograd.Function):
            # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
            # is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2
            if not ctx.is_input_fp8:
                qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_"))
                if qkv_group == 1:
                    dim = ctx.qkv_layout.find("3")
                    dqkv_fp8_data = _combine_tensors(
@@ -4682,6 +4941,8 @@ class FusedAttnFunc(torch.autograd.Function):
            None,
            None,
            None,
            None,
            None,
            dq,
            dk,
            dv,
@@ -4712,6 +4973,8 @@ class FusedAttnFunc(torch.autograd.Function):
            None,
            None,
            None,
            None,
            None,
            dq,
            dk,
            dv,
@@ -4833,6 +5096,7 @@ class FusedAttention(torch.nn.Module):
        fp8_meta: Optional[Dict[str, Any]] = None,
        quantizers=None,
        pad_between_seqs: bool = False,
        inference_params: Optional[InferenceParams] = None,
    ) -> torch.Tensor:
        """fused attention fprop"""
        assert (
@@ -4857,60 +5121,63 @@ class FusedAttention(torch.nn.Module):
            cp_size *= get_distributed_world_size(group)
        context_parallel = cp_size > 1
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) # get q_format and kv_format for training and inference
qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params)
if qkv_format in ["sbhd", "bshd"]: page_table = None
if qkv_format == "sbhd": if inference_params is None:
batch_size, max_seqlen_q, max_seqlen_kv = ( if qkv_format in ["sbhd", "bshd"]:
query_layer.shape[1], if qkv_format == "sbhd":
query_layer.shape[0], batch_size = query_layer.shape[1]
key_layer.shape[0], max_seqlen_q = query_layer.shape[0]
) max_seqlen_kv = key_layer.shape[0]
if qkv_format == "bshd": if qkv_format == "bshd":
batch_size, max_seqlen_q, max_seqlen_kv = ( batch_size = query_layer.shape[0]
query_layer.shape[0], max_seqlen_q = query_layer.shape[1]
query_layer.shape[1], max_seqlen_kv = key_layer.shape[1]
key_layer.shape[1], max_seqlen_q *= cp_size
) max_seqlen_kv *= cp_size
max_seqlen_q *= cp_size if "padding" in attn_mask_type:
max_seqlen_kv *= cp_size assert (
if "padding" in attn_mask_type: not context_parallel
assert not context_parallel, "Padding mask not supported with context parallelism!" ), "Padding mask not supported with context parallelism!"
if cu_seqlens_q is None or cu_seqlens_kv is None:
if cu_seqlens_q is None or cu_seqlens_kv is None: if attention_mask is None:
if attention_mask is None: raise RuntimeError(
raise RuntimeError( "Please provide attention_mask or cu_seqlens for padding!"
"Please provide attention_mask or cu_seqlens for padding!" )
if self.attention_type == "self":
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask)
cu_seqlens_kv = cu_seqlens_q
else:
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0])
cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
else:
if cu_seqlens_q is None:
cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
batch_size,
max_seqlen_q,
query_layer.device,
) )
if self.attention_type == "self": if cu_seqlens_kv is None:
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask) cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
cu_seqlens_kv = cu_seqlens_q batch_size,
else: max_seqlen_kv,
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0]) key_layer.device,
cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1]) )
else: if qkv_format == "thd":
if cu_seqlens_q is None: assert (
cu_seqlens_q = dpa_utils.get_full_cu_seqlens( max_seqlen_q is not None
batch_size, and max_seqlen_kv is not None
max_seqlen_q, and cu_seqlens_q is not None
query_layer.device, and cu_seqlens_kv is not None
) ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
if cu_seqlens_kv is None: elif inference_params.is_paged:
cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( page_table = inference_params.cache_manager.page_table
batch_size,
max_seqlen_kv, if (q_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_q_padded is None:
key_layer.device,
)
if qkv_format == "thd":
assert (
max_seqlen_q is not None
and max_seqlen_kv is not None
and cu_seqlens_q is not None
and cu_seqlens_kv is not None
), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
if qkv_format == "thd" and (cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None):
cu_seqlens_q_padded = cu_seqlens_q cu_seqlens_q_padded = cu_seqlens_q
if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None:
cu_seqlens_kv_padded = cu_seqlens_kv cu_seqlens_kv_padded = cu_seqlens_kv
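For reference, dpa_utils.get_full_cu_seqlens presumably amounts to the following (a sketch, assuming every sequence occupies the full max_seqlen):

import torch

def get_full_cu_seqlens_ref(batch_size, max_seqlen, device):
    # [0, max_seqlen, 2 * max_seqlen, ..., batch_size * max_seqlen]
    return torch.arange(
        0, (batch_size + 1) * max_seqlen, max_seqlen, dtype=torch.int32, device=device
    )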
        use_FAv2_bwd = (
@@ -4981,6 +5248,8 @@ class FusedAttention(torch.nn.Module):
            cu_seqlens_kv,
            cu_seqlens_q_padded,
            cu_seqlens_kv_padded,
            page_table,
            page_table,
            query_layer,
            key_layer,
            value_layer,
@@ -5369,14 +5638,14 @@ class DotProductAttention(TransformerEngineBaseModule):
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
        qkv_format: str = None,
        cu_seqlens_q: torch.Tensor = None,
        cu_seqlens_kv: torch.Tensor = None,
        cu_seqlens_q_padded: torch.Tensor = None,
        cu_seqlens_kv_padded: torch.Tensor = None,
        max_seqlen_q: int = None,
        max_seqlen_kv: int = None,
        attn_mask_type: Optional[str] = None,
        window_size: Optional[Tuple[int, int]] = None,
        checkpoint_core_attention: bool = False,
@@ -5565,6 +5834,16 @@ class DotProductAttention(TransformerEngineBaseModule):
            num_gemms=3,
            allow_non_contiguous=True,
        ) as query_layer:
            # checks for RNG
            if self.rng_states_tracker is not None and is_graph_capturing():
                assert isinstance(
                    self.rng_states_tracker, CudaRNGStatesTracker
                ), "Unsupported RNG states tracker."
                assert (
                    graph_safe_rng_available()
                ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."

            # checks for FP8
            if self.fp8:
                if self.fp8_meta["recipe"].fp8_mha:
                    if not self.fp8_meta["recipe"].fp8_dpa:
@@ -5573,7 +5852,6 @@ class DotProductAttention(TransformerEngineBaseModule):
                        """Forcing fp8_meta["recipe"].fp8_dpa=True due to """
                        """fp8_meta["recipe"].fp8_mha=True"""
                    )
            if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
                forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
                backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
@@ -5585,6 +5863,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    tex.DType.kFloat8E5M2,
                ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""

            # checks for q/k/v shapes
            assert (
                query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
            ), "DotProductAttention only supports CUDA tensors."
@@ -5594,18 +5873,26 @@ class DotProductAttention(TransformerEngineBaseModule):
            assert (
                key_layer.shape[:-1] == value_layer.shape[:-1]
            ), "Keys and values must have the same batch size, sequence length and number of heads!"
            num_attention_heads = query_layer.shape[-2]
            num_gqa_groups = key_layer.shape[-2]
            assert (
                query_layer.shape[-1] == key_layer.shape[-1]
            ), "Queries and keys must have the same head dimension!"
            head_dim_qk, head_dim_v = query_layer.shape[-1], value_layer.shape[-1]
            assert head_dim_qk == self.hidden_size_per_attention_head_k, (
                f"Keys have head_dim = {head_dim_qk}, but expected head_dim ="
                f" {self.hidden_size_per_attention_head_k}!"
            )
            assert head_dim_v == self.hidden_size_per_attention_head_v, (
                f"Values have head_dim = {head_dim_v}, but expected head_dim ="
                f" {self.hidden_size_per_attention_head_v}!"
            )
            assert num_gqa_groups == self.num_gqa_groups_per_partition, (
                "Keys and values must have num_gqa_group ="
                f" {self.num_gqa_groups_per_partition} heads! Found {num_gqa_groups}."
            )

            # checks for attention mask
            if attn_mask_type is None:
                attn_mask_type = self.attn_mask_type
            else:
@@ -5615,82 +5902,40 @@ class DotProductAttention(TransformerEngineBaseModule):
                assert (
                    attn_mask_type in AttnMaskTypes
                ), f"Attention mask type {attn_mask_type} is not supported!"

            # checks for sliding window
            if window_size is None:
                window_size = self.window_size
            window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
            # checks for qkv_format
            if qkv_format is None:
                qkv_format = self.qkv_format
            assert qkv_format in [
                "sbhd",
                "bshd",
                "thd",
            ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"

            batch_size = None
            if qkv_format in ["sbhd", "bshd"]:
                assert all(
                    len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
                ), f"Queries, keys and values must be 4D tensors when {qkv_format=}!"
                if qkv_format == "sbhd":
                    batch_size = query_layer.shape[1]
                    max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
                else:
                    batch_size = query_layer.shape[0]
                    max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
                    max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
            if qkv_format == "thd":
                assert all(
                    len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
                ), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
                assert (
                    "padding" in attn_mask_type
                ), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
                assert (
                    cu_seqlens_q is not None and cu_seqlens_kv is not None
                ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
@@ -5716,6 +5961,76 @@ class DotProductAttention(TransformerEngineBaseModule):
                    seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
                    max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
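The (x + 63) // 64 * 64 expression above rounds the longest cached length up to a multiple of 64, e.g.:

for x in [1, 64, 100, 129]:
    print(x, "->", (x + 63) // 64 * 64)  # 64, 64, 128, 192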
            # update KV cache and retrieve saved tokens from cache for inference
            if inference_params is not None:
                assert self.layer_number is not None, "Layer number must be set!"
                # convert top-left causal to bottom-right causal due to KV caching
                # users can still use the same attention mask for inference as for training
                assert "padding" in attn_mask_type, "KV caching requires padding mask!"
                if attn_mask_type == "padding_causal":
                    attn_mask_type = attn_mask_type + "_bottom_right"
                self.attention_type = "cross"
                self.flash_attention.attention_type = self.attention_type
                self.fused_attention.attention_type = self.attention_type
                self.unfused_attention.attention_type = self.attention_type
                query_layer, key_layer, value_layer = [
                    x.contiguous() if not x.is_contiguous() else x
                    for x in [query_layer, key_layer, value_layer]
                ]
                # get full K/V tensors from cache and adjust cu_seqlens, qkv_format based on the cache
                (
                    key_layer,
                    value_layer,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_kv,
                    qkv_format,
                ) = inference_params.step(
                    self.layer_number,
                    key_layer,
                    value_layer,
                    qkv_format,
                )
                cu_seqlens_q_padded = None
                cu_seqlens_kv_padded = None
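A schematic of the step() contract used above (illustrative only; ToyKVCache is not TE's cache manager): each call hands the new K/V tokens of one layer to the cache and gets back the full cached K/V plus updated per-sequence lengths.

import torch

class ToyKVCache:
    """Toy non-paged, bshd cache for one layer (a sketch, not TE's implementation)."""

    def __init__(self, max_batch, max_seqlen, h, d, dtype=torch.float32):
        self.k = torch.zeros(max_batch, max_seqlen, h, d, dtype=dtype)
        self.v = torch.zeros_like(self.k)
        self.lens = torch.zeros(max_batch, dtype=torch.int32)

    def step(self, new_k, new_v):
        # new_k/new_v: [b, s_new, h, d]; append after the cached tokens per sequence
        b, s_new = new_k.shape[:2]
        for i in range(b):
            start = int(self.lens[i])
            self.k[i, start : start + s_new] = new_k[i]
            self.v[i, start : start + s_new] = new_v[i]
        self.lens[:b] += s_new
        # return the full cached K/V plus the updated lengths
        return self.k[:b], self.v[:b], self.lens[:b].clone()

cache = ToyKVCache(max_batch=2, max_seqlen=8, h=1, d=4)
k_full, v_full, lens = cache.step(torch.randn(2, 1, 1, 4), torch.randn(2, 1, 1, 4))

The real inference_params.step additionally rewrites cu_seqlens_q/kv and reports the cache's qkv_format (e.g. thd_2bshd), so the backends can run on the combined layout.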
            # get qkv's memory layout
            if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
                (
                    qkv_layout,
                    query_layer._data,
                    key_layer._data,
                    value_layer._data,
                    q_format,
                    kv_format,
                ) = dpa_utils.get_qkv_layout(
                    query_layer._data,
                    key_layer._data,
                    value_layer._data,
                    qkv_format=qkv_format,
                    inference_params=inference_params,
                )
            else:
                (
                    qkv_layout,
                    query_layer,
                    key_layer,
                    value_layer,
                    q_format,
                    kv_format,
                ) = dpa_utils.get_qkv_layout(
                    query_layer,
                    key_layer,
                    value_layer,
                    qkv_format=qkv_format,
                    inference_params=inference_params,
                )
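A sketch of the relationship between the returned values (the helper below is illustrative, not dpa_utils.get_qkv_layout itself): the layout string carries one format per tensor, plus a paged_kv_ prefix during paged inference, and q_format/kv_format are its components.

def split_layout(qkv_layout):
    # e.g. "paged_kv_thd_bshd_bshd" -> ("thd", "bshd")
    fmts = qkv_layout.replace("paged_kv_", "").split("_")
    q_format = "".join(c for c in fmts[0] if c.isalpha())
    kv_format = "".join(c for c in fmts[-1] if c.isalpha())
    return q_format, kv_format

assert split_layout("paged_kv_thd_bshd_bshd") == ("thd", "bshd")
assert split_layout("sbh3d") == ("sbhd", "sbhd")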
            # adjust max_seqlen and cu_seqlens for CP
            cp_size = 1
            if isinstance(self.cp_group, dist_group_type):
                cp_size = get_distributed_world_size(self.cp_group)
@@ -5723,71 +6038,42 @@ class DotProductAttention(TransformerEngineBaseModule):
                for group in self.cp_group:
                    cp_size *= get_distributed_world_size(group)
            context_parallel = cp_size > 1
            if q_format in ["sbhd", "bshd"]:
                max_seqlen_q *= cp_size
                if cu_seqlens_q is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        if self.attention_type == "self":
                            cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask)
                        else:
                            cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0])
                    else:
                        cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_q,
                            query_layer.device,
                        )
            if kv_format in ["sbhd", "bshd"]:
                max_seqlen_kv *= cp_size
                if cu_seqlens_kv is None:
                    if "padding" in attn_mask_type:
                        assert (
                            attention_mask is not None
                        ), "Please provide attention_mask for padding!"
                        if self.attention_type == "self":
                            cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask)
                        else:
                            cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
                    else:
                        cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
                            batch_size,
                            max_seqlen_kv,
                            key_layer.device,
                        )

            # set ALiBi attributes
            global _alibi_cache
            if alibi_slopes is not None:
                assert (
@@ -5811,6 +6097,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    _alibi_cache["_alibi_slopes_require_update"] = True
                    _alibi_cache["_alibi_bias_require_update"] = True

            # detect bias shape
            core_attention_bias_shape = None
            if core_attention_bias is not None:
                if (
@@ -5846,17 +6133,18 @@ class DotProductAttention(TransformerEngineBaseModule):
            else:
                pad_between_seqs = False

            # gather attention params for get_attention_backend
            attention_params = dpa_utils.AttentionParams(
                qkv_type=type(query_layer),
                qkv_dtype=query_layer.dtype,
                qkv_layout=qkv_layout,
                batch_size=batch_size,
                num_heads=num_attention_heads,
                num_gqa_groups=num_gqa_groups,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                head_dim_qk=head_dim_qk,
                head_dim_v=head_dim_v,
                attn_mask_type=attn_mask_type,
                window_size=window_size,
                alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
@@ -5872,6 +6160,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                is_training=self.training,
                fp8=self.fp8,
                fp8_meta=self.fp8_meta,
                inference_params=inference_params,
            )
            global _attention_backends
            if (
@@ -5881,9 +6170,9 @@ class DotProductAttention(TransformerEngineBaseModule):
                _attention_backends["attention_params"] = attention_params
                _attention_backends["backend_selection_requires_update"] = True
            if _attention_backends["backend_selection_requires_update"]:
                fa_utils.use_v3 = fa_utils.v3_is_installed
                (
                    use_flash_attention,
                    flash_attention_backend,
                    use_fused_attention,
                    fused_attention_backend,
                    use_unfused_attention,
@@ -5892,6 +6181,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                # Set global _attention_backends var using return value
                # from get_attention_backend()
                _attention_backends["use_flash_attention"] = use_flash_attention
                _attention_backends["flash_attention_backend"] = flash_attention_backend
                _attention_backends["use_fused_attention"] = use_fused_attention
                _attention_backends["fused_attention_backend"] = fused_attention_backend
                _attention_backends["use_unfused_attention"] = use_unfused_attention
@@ -5899,7 +6189,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                if use_flash_attention:
                    self.logger.info(
                        "Running with FlashAttention backend (version %s)",
                        flash_attention_backend,
                    )
                elif use_fused_attention:
                    self.logger.info(
@@ -5910,10 +6200,16 @@ class DotProductAttention(TransformerEngineBaseModule):
                    self.logger.info("Running with UnfusedDotProductAttention backend")
            else:
                use_flash_attention = _attention_backends["use_flash_attention"]
                flash_attention_backend = _attention_backends["flash_attention_backend"]
                use_fused_attention = _attention_backends["use_fused_attention"]
                fused_attention_backend = _attention_backends["fused_attention_backend"]
                use_unfused_attention = _attention_backends["use_unfused_attention"]

            # raise exception if no backend is available
            if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
                raise ValueError("No dot product attention support for the provided inputs!")

            # run attention
            if use_flash_attention:
                if core_attention_bias_type == "alibi":
                    alibi_slopes, _ = dpa_utils.get_alibi(
@@ -5943,6 +6239,8 @@ class DotProductAttention(TransformerEngineBaseModule):
                    fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                    fp8_meta=self.fp8_meta,
                    quantizers=self.quantizers,
                    inference_params=inference_params,
                    flash_attention_backend=flash_attention_backend,
                )

            if use_fused_attention:
@@ -5961,6 +6259,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    bias_dtype=query_layer.dtype,
                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
                )
                # checkpoint_core_attention=False
                if checkpoint_core_attention:
                    return self._checkpointed_attention_forward(
                        self.fused_attention,
@@ -5987,7 +6286,9 @@ class DotProductAttention(TransformerEngineBaseModule):
                        cp_comm_type=self.cp_comm_type,
                        fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                        fp8_meta=self.fp8_meta,
                        quantizers=self.quantizers,
                        pad_between_seqs=pad_between_seqs,
                        inference_params=inference_params,
                    )
                return self.fused_attention(
                    query_layer,
@@ -6015,6 +6316,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    fp8_meta=self.fp8_meta,
                    quantizers=self.quantizers,
                    pad_between_seqs=pad_between_seqs,
                    inference_params=inference_params,
                )

            from .cpu_offload import CPUOffloadEnabled
@@ -6041,6 +6343,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                    core_attention_bias_type=core_attention_bias_type,
                    core_attention_bias=core_attention_bias,
                    alibi_slopes=alibi_slopes,
                    inference_params=inference_params,
                )
            return self.unfused_attention(
                query_layer,
@@ -6055,9 +6358,9 @@ class DotProductAttention(TransformerEngineBaseModule):
                core_attention_bias_type=core_attention_bias_type,
                core_attention_bias=core_attention_bias,
                alibi_slopes=alibi_slopes,
                inference_params=inference_params,
            )
            return None
class MultiheadAttention(torch.nn.Module):
@@ -6241,7 +6544,7 @@ class MultiheadAttention(torch.nn.Module):
        self.qkv_format = qkv_format
        self.attn_mask_type = attn_mask_type
        self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
        self.layer_number = 1 if layer_number is None else layer_number
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.get_rng_state_tracker = get_rng_state_tracker
@@ -6410,19 +6713,6 @@ class MultiheadAttention(torch.nn.Module):
            **common_gemm_kwargs,
        )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
@@ -6611,31 +6901,14 @@ class MultiheadAttention(torch.nn.Module):
        ), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

        # =================================================
        # Pre-allocate memory for key-value cache for inference
        # =================================================
        if (
            inference_params is not None
            and self.layer_number not in inference_params.cache_manager.cache
        ):
            inference_params.allocate_memory(self.layer_number)
        # ======================
        # Query, Key, and Value
@@ -6801,9 +7074,12 @@ class MultiheadAttention(torch.nn.Module):
            elif self.qkv_format == "bshd":
                sequence_length = key_layer.size(1)
            else:
                raise ValueError(
                    f"qkv_format={self.qkv_format} not supported for KV caching and RoPE."
                )

            sequence_start = inference_params.get_seqlens_pre_step()
            # sequence_start = inference_params.seqlens[0]
            sequence_end = sequence_start + sequence_length
            q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
@@ -64,6 +64,16 @@ QKVLayouts = (
    "thd_t2hd",
    "thd_th2d",
    "thd_thd_thd",
    "sbhd_bshd_bshd",
    "bshd_sbhd_sbhd",
    "thd_bshd_bshd",
    "thd_sbhd_sbhd",
    "paged_kv_bshd_bshd_bshd",
    "paged_kv_bshd_sbhd_sbhd",
    "paged_kv_sbhd_bshd_bshd",
    "paged_kv_sbhd_sbhd_sbhd",
    "paged_kv_thd_bshd_bshd",
    "paged_kv_thd_sbhd_sbhd",
)

LayerTypes = ("encoder", "decoder")
@@ -9,6 +9,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import (
    NVTE_QKV_Layout,
    NVTE_QKV_Format,
    NVTE_Bias_Type,
    NVTE_Mask_Type,
    NVTE_Fused_Attn_Backend,
@@ -31,6 +32,16 @@ TORCH_DType = {
    tex.DType.kInt32: torch.int32,
}
QKVFormat = {
    "bshd": NVTE_QKV_Format.NVTE_BSHD,
    "sbhd": NVTE_QKV_Format.NVTE_SBHD,
    "thd": NVTE_QKV_Format.NVTE_THD,
    "sbhd_2bshd": NVTE_QKV_Format.NVTE_SBHD_2BSHD,
    "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD,
    "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD,
    "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD,
}

QKVLayout = {
    "sb3hd": NVTE_QKV_Layout.NVTE_SB3HD,
    "sbh3d": NVTE_QKV_Layout.NVTE_SBH3D,
@@ -47,6 +58,16 @@ QKVLayout = {
    "thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD,
    "thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D,
    "thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD,
    "sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_SBHD_BSHD_BSHD,
    "bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_BSHD_SBHD_SBHD,
    "thd_bshd_bshd": NVTE_QKV_Layout.NVTE_THD_BSHD_BSHD,
    "thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_THD_SBHD_SBHD,
    "paged_kv_bshd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_BSHD_BSHD,
    "paged_kv_bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_SBHD_SBHD,
    "paged_kv_sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_BSHD_BSHD,
    "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD,
    "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD,
    "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD,
}
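The new entries follow a fixed naming convention; a tiny illustrative check of the string-to-enum-name mapping:

def expected_enum_name(layout):
    # "paged_kv_thd_bshd_bshd" -> "NVTE_Paged_KV_THD_BSHD_BSHD"
    if layout.startswith("paged_kv_"):
        return "NVTE_Paged_KV_" + layout[len("paged_kv_"):].upper()
    return "NVTE_" + layout.upper()

assert expected_enum_name("paged_kv_thd_bshd_bshd") == "NVTE_Paged_KV_THD_BSHD_BSHD"
assert expected_enum_name("thd_bshd_bshd") == "NVTE_THD_BSHD_BSHD"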
AttnBiasType = {
@@ -100,6 +121,8 @@ def fused_attn_fwd(
    attn_bias: torch.Tensor = None,
    cu_seqlens_q_padded: torch.Tensor = None,
    cu_seqlens_kv_padded: torch.Tensor = None,
    page_table_k: torch.Tensor = None,
    page_table_v: torch.Tensor = None,
    s_quantizer: Quantizer = None,
    o_quantizer: Quantizer = None,
    attn_scale: float = None,
@@ -148,6 +171,10 @@ def fused_attn_fwd(
    cu_seqlens_q_padded: torch.Tensor, default = None
        cumulative sequence offsets for Q; shape [batch_size + 1]
    cu_seqlens_kv_padded: torch.Tensor, default = None
        cumulative sequence offsets for KV; shape [batch_size + 1]
    page_table_k: torch.Tensor, default = None
        page table for K cache; shape [batch_size, max_pages_per_seq_k]
    page_table_v: torch.Tensor, default = None
        page table for V cache; shape [batch_size, max_pages_per_seq_v]
    s_quantizer: Quantizer, default = None
        Quantizer object for the intermediate value S.
    o_quantizer: Quantizer, default = None
@@ -268,6 +295,8 @@ def fused_attn_fwd(
        fake_dtype,
        cu_seqlens_q_padded,
        cu_seqlens_kv_padded,
        page_table_k,
        page_table_v,
        s_quantizer,
        o_quantizer,
        attn_bias,
@@ -14,6 +14,8 @@
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
@@ -51,8 +51,9 @@ std::vector<py::object> fused_attn_fwd(
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);

std::vector<py::object> fused_attn_bwd(
@@ -69,6 +70,13 @@ std::vector<py::object> fused_attn_bwd(

at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);

void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
                      torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
                      torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b,
                      int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged);
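A hedged sketch of how this binding might be driven from Python, with argument order per the declaration above and shapes per the kernel comments further down (the toy sizes, and the assumption that the function is exposed as tex.copy_to_kv_cache, are not from this commit):

import torch
import transformer_engine_torch as tex
from transformer_engine_torch import NVTE_QKV_Format

b, s_new, h_kv, d = 2, 1, 4, 64          # one new token per sequence (decode step)
max_seq_len, max_ctx_len = 128, 16
new_k = torch.randn(b, s_new, h_kv, d, dtype=torch.bfloat16, device="cuda")
new_v = torch.randn_like(new_k)
k_cache = torch.zeros(b, max_seq_len, h_kv, d, dtype=torch.bfloat16, device="cuda")
v_cache = torch.zeros_like(k_cache)
# non-paged: the page table is just the batch indices, one "page" per sequence
page_table = torch.arange(b, dtype=torch.int32, device="cuda").unsqueeze(1)
cu_new_lens = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
cu_cached_lens = torch.tensor([0, 8, 16], dtype=torch.int32, device="cuda")  # incl. new tokens
tex.copy_to_kv_cache(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
                     cu_cached_lens, NVTE_QKV_Format.NVTE_BSHD, b, max_ctx_len,
                     max_seq_len, 1, True)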
/***************************************************************************************************
 * GEMM
 **************************************************************************************************/
@@ -4,6 +4,7 @@
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"
#include "kv_cache.cuh"
#include "thd_utils.cuh"

constexpr int block_size = 512;
@@ -90,8 +91,9 @@ std::vector<py::object> fused_attn_fwd(
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
    const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
    py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;
@@ -126,6 +128,7 @@ std::vector<py::object> fused_attn_fwd(
  TensorWrapper te_Bias;
  TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
  TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
  TensorWrapper te_page_table_k, te_page_table_v;
  if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
    // FP8
    auto h = q_shape[q_shape.size() - 2];
@@ -170,6 +173,19 @@ std::vector<py::object> fused_attn_fwd(
        cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32);
  }
  if ((page_table_k.has_value()) && (page_table_v.has_value())) {
    auto page_table_k_sizes = page_table_k.value().sizes().vec();
    std::vector<size_t> page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()};
    auto page_table_v_sizes = page_table_v.value().sizes().vec();
    std::vector<size_t> page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()};
    te_page_table_k =
        makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape,
                                    DType::kInt32, nullptr, nullptr, nullptr);
    te_page_table_v =
        makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape,
                                    DType::kInt32, nullptr, nullptr, nullptr);
  }
  // extract rng seed and offset
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
      rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
@@ -187,13 +203,13 @@ std::vector<py::object> fused_attn_fwd(
  TensorWrapper workspace;

  // populate tensors with appropriate shapes and dtypes
  nvte_fused_attn_fwd(
      te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
      &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
      te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
      te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
      workspace.data(), at::cuda::getCurrentCUDAStream());

  // allocate memory for workspace and auxiliary output tensors
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
@@ -241,13 +257,13 @@ std::vector<py::object> fused_attn_fwd(
  }

  // execute the kernel
  nvte_fused_attn_fwd(
      te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
      &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
      te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
      te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
      attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
      workspace.data(), at::cuda::getCurrentCUDAStream());

  // destroy tensor wrappers, but not allocated memory
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
@@ -1012,3 +1028,174 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
  return output;
}
/***************************************************************************************************
 * KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
 **************************************************************************************************/

template <typename scalar_t>
void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
                                  int b, int max_seq_len, int h, int d) {
  transformer_engine::fused_attn::
      convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
          reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
          reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
          b, max_seq_len, h, d);
}

at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) {
  int h = tensor.size(1);
  int d = tensor.size(2);
  std::vector<int64_t> shape = {b, max_seq_len, h, d};
  at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
  if (new_tensor.scalar_type() == at::ScalarType::Half) {
    using dtype = at::Half;
    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
  } else if (new_tensor.scalar_type() == at::ScalarType::BFloat16) {
    using dtype = at::BFloat16;
    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
  } else if (new_tensor.scalar_type() == at::ScalarType::Float) {
    using dtype = float;
    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
  } else if (new_tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
    using dtype = at::Float8_e4m3fn;
    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
  } else if (new_tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
    using dtype = at::Float8_e5m2;
    convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
  } else {
    NVTE_ERROR("Unsupported dtype for KV cache.\n");
  }
  return new_tensor;
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
template <typename scalar_t>
void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn::
convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d);
}
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
int b = tensor.size(0);
int max_seq_len = tensor.size(1);
int h = tensor.size(2);
int d = tensor.size(3);
std::vector<int64_t> shape = {t, h, d};
at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
if (tensor.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float) {
using dtype = float;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
return new_tensor;
}
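/***************************************************************************************************
 * Example (illustrative only): the inverse of the example above; given the same cu_seqlens and the
 * total token count t = 4, the padding is stripped again:
 *
 *   at::Tensor thd_again = convert_bshd_to_thd(bshd, cu_seqlens, /*t=*/4);
 *   // thd_again has shape [4, 2, 8] and matches thd for the valid (non-pad) tokens.
 **************************************************************************************************/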
/***************************************************************************************************
 * KV Cache: Copy new KV tokens to the KV cache
 * 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
 * 2. cu_new_lens and cu_cached_lens have shape [b + 1]; cu_cached_lens already includes the
 *    lengths added in the current step
 * 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
 *    max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged,
 *    and set is_non_paged = True/False to distinguish the two.
 * 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
 *    become [0, 1, 2, 3]. The page_table = batch_indices.unsqueeze(1) is however left unchanged;
 *    batch_indices_post can be used for monotonic indexing, i.e. [0, 1, 2, 3], while batch_indices
 *    is preserved for the next layer in the same iteration.
 * 5. Only the same page_table for k_cache and v_cache is supported
 * 6. Only pad_between_seqs = False is supported when qkv_format = thd, i.e. new_k and new_v must
 *    not contain pad tokens between sequences, such as [a a a 0..0 b b 0..0 c 0..0].
 **************************************************************************************************/
template <typename scalar_t>
void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache,
at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens,
at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr &&
v_cache.data_ptr() != nullptr) {
if (is_non_paged) {
transformer_engine::fused_attn::
reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()),
page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
}
transformer_engine::fused_attn::
copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()), page_table.data_ptr<int>(),
cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v,
b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
}
void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
int h_kv = new_k.size(-2);
int d_k = new_k.size(-1);
int d_v = new_v.size(-1);
NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() &&
new_k.scalar_type() == new_v.scalar_type() &&
new_k.scalar_type() == k_cache.scalar_type(),
"new_k, new_v, k_cache and v_cache must be of the same data type.");
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
if (k_cache.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float) {
using dtype = float;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
}
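/***************************************************************************************************
 * Example for the non-paged case (illustrative only; names, lengths, and shapes are hypothetical):
 * per the notes above, a non-paged cache uses page_table = batch_indices.unsqueeze(1) of shape
 * [b, 1] and max_pages_per_seq = 1, so the same kernel serves both layouts. Here each of the two
 * sequences appends one new token (a decode step) on top of 4 already-cached tokens:
 *
 *   int b = 2, max_seq_len = 16, h_kv = 2, d = 8;
 *   auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
 *   at::Tensor k_cache = at::zeros({b, max_seq_len, h_kv, d}, opts);
 *   at::Tensor v_cache = at::zeros({b, max_seq_len, h_kv, d}, opts);
 *   at::Tensor batch_indices =
 *       at::tensor({0, 1}, at::TensorOptions().dtype(at::kInt)).to(at::kCUDA);
 *   at::Tensor page_table = batch_indices.unsqueeze(1);  // [b, 1]
 *   at::Tensor new_k = at::randn({b, 1, h_kv, d}, opts);  // qkv_format = bshd
 *   at::Tensor new_v = at::randn({b, 1, h_kv, d}, opts);
 *   at::Tensor cu_new_lens =
 *       at::tensor({0, 1, 2}, at::TensorOptions().dtype(at::kInt)).to(at::kCUDA);
 *   // 4 cached + 1 new tokens per sequence -> cumulative cached lengths {0, 5, 10}:
 *   at::Tensor cu_cached_lens =
 *       at::tensor({0, 5, 10}, at::TensorOptions().dtype(at::kInt)).to(at::kCUDA);
 *   copy_to_kv_cache(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, cu_cached_lens,
 *                    NVTE_QKV_Format::NVTE_BSHD, b, /*max_ctx_len=*/1, max_seq_len,
 *                    /*max_pages_per_seq=*/1, /*is_non_paged=*/true);
 **************************************************************************************************/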