Unverified Commit 4f33ece4 authored by Charlene Yang, committed by GitHub

Add KV cache for paged/non-paged attention (#1355)



* add paged attention; test_kv_cache_accuracy and test_paged_attn pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove unnecessary change from last commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test_fused_attn pass
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove unnecessary import in test_numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add license for test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add to L0 test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update license for test_paged_attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update kv_cache_manager license
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix build issue from previous merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: minor fix/preparation for inference/cuda graph
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, bshd/sbhd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, thd, no CG
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: non-paged, thd, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: non-paged, using paged kernel
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: restructure kernels
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: paged, CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: padding + BRCM
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: restructure IP, clean up
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix non-CG, fused
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix last commit
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: unfused, non-CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash-attn, non-CG
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash_attn_with_kvcache
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* commit two files missed by bcef6b34
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: thd_bshd_bshd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix last commit
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix 1c31b68d
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: add bshd_2sbhd, sbhd_2bshd
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: some cleanup
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: all qkv_format combinations and merge CM files
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: some lint fixes
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: add docstring for IP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix sequences_pre
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: minor fixes for multi-layer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: initial multi-layer test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: minor clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: switch to flash_attn_varlen_func
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix unfused for separate q/kv format
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix fused for separate q/kv formats
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: flash attn + TELayer + 2 layers
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: unfused + TL + 2layers
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: all modules/backend
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: minor cleanup
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: FlashAttention on Hopper with 2.7.3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: FlashAttention + v3 from 39e7179
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: FlashAttention + v3 + FP8 + WIP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: add backend support table
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: clean up
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: separate use_flash_attention_2 and _3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* WIP: tweaks to paged attn script
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* WIP: enable/disable certain cases for fused attn
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: small fixes for lint and cg
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: minor fixes for attn/infer
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: fix CP
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* WIP: readd page info to FADescriptor_v1
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor tweak to test_numerics.py
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* fix 9.5/9.7 sq/skv + mask logic
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* clean up
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fix for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* more minor fixes for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* test page_size=1 for FA3
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix t3hd/th3d strides
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix ckpt recompute and fa3 k_scale
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* raise dynamo recompile limit for test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove thunder test from L0
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix FA selection logic
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA3 q_descale shape
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove page_table from IP.step() returns
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FP8 FlashAttn DPA fp8_dpa tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor tweaks
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FA3 note and L3 test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove redundant import in test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* adopt new FA3 APIs from FA2.7.3+/hopper for CP and non-CP
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix lint
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* relax tols for TransformerLayers
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix merge
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix merge 2
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix FA import comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* relax tols for Ampere
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix fa3 version and reduce messaging
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update FA3 to its latest commit on main
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add default values to IP and assertion to graph.py
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add more comments in attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use custom_cache_manager instead of cache_manager
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 05f6a691
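For context, a rough sketch of the decode-step workflow the new tests below exercise. The module, class, and argument names are taken from this diff (DotProductAttention, InferenceParams, allocate_memory, pre_step); the shapes, dtype, and numeric values are arbitrary, and the exact signatures should be treated as illustrative rather than authoritative.

# Illustrative sketch only; mirrors tests/pytorch/fused_attn/test_paged_attn.py below.
from collections import OrderedDict
import torch
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

batch_size, num_heads, num_gqa_groups, head_dim = 4, 16, 16, 128
max_seqlen_kv, page_size, ctx_len = 1024, 256, 16

attn = (
    DotProductAttention(
        kv_channels=head_dim,
        num_attention_heads=num_heads,
        num_gqa_groups=num_gqa_groups,
        layer_number=1,
        attn_mask_type="padding_causal",
        qkv_format="bshd",
    )
    .cuda()
    .eval()
)

# One KV cache reused across steps; is_paged=True selects the paged cache manager.
inference_params = InferenceParams(
    max_batch_size=batch_size,
    max_seqlen_kv=max_seqlen_kv,
    num_heads_kv=num_gqa_groups,
    head_dim_k=head_dim,
    head_dim_v=head_dim,
    dtype=torch.bfloat16,
    is_paged=True,
    page_size=page_size,
    total_num_pages=batch_size * max_seqlen_kv // page_size,
    max_ctx_len=ctx_len,
    qkv_format="bshd",
)
inference_params.allocate_memory(layer_number=1)

# Each step: tell the cache how many new tokens every sequence contributes
# (its context length on the first step, typically 1 per decode step after that).
step_dict = OrderedDict((seq_id, ctx_len) for seq_id in range(batch_size))
inference_params.pre_step(step_dict)

q = torch.randn(batch_size, ctx_len, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch_size, ctx_len, num_gqa_groups, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch_size, ctx_len, num_gqa_groups, head_dim, dtype=torch.bfloat16, device="cuda")
cu_seqlens_q = ctx_len * torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda")

out = attn(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_kv=cu_seqlens_q.clone(),
    inference_params=inference_params,
    max_seqlen_q=ctx_len,
    max_seqlen_kv=max_seqlen_kv,
)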
......@@ -38,7 +38,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
exit 1
......
......@@ -17,7 +17,7 @@ if [ $sm_arch -gt 90 ]
then
FA_versions=(2.7.3)
else
FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
FA_versions=(2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
fi
for fa_version in "${FA_versions[@]}"
......@@ -28,10 +28,12 @@ do
then
pip3 install flash-attn==${fa_version}
else
pip3 install "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper"
python_path=`python3 -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flashattn_hopper
wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/v2.7.2/hopper/flash_attn_interface.py
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
python_path=`python -c "import site; print(site.getsitepackages()[0])"`
mkdir -p $python_path/flash_attn_3
wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
cd ../../
fi
# Run tests
......
......@@ -26,6 +26,7 @@ from transformer_engine.pytorch.dot_product_attention.utils import (
check_set_window_size,
AttentionParams,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
......@@ -96,6 +97,8 @@ class ModelConfig:
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = None,
max_ctx_len: int = None,
):
self.batch_size = batch_size
self.num_heads = num_heads
......@@ -114,6 +117,8 @@ class ModelConfig:
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
@contextmanager
......@@ -136,6 +141,8 @@ def _get_attention_backends(
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
is_training: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""
......@@ -165,6 +172,7 @@ def _get_attention_backends(
fused_attn_backends = []
available_backends = None
flash_attention_backend = None
fused_attention_backend = None
def test():
......@@ -190,10 +198,13 @@ def _get_attention_backends(
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
)
(
use_flash_attention,
use_fused_attention,
flash_attention_backend,
fused_attention_backend,
use_unfused_attention,
available_backends,
......@@ -202,20 +213,21 @@ def _get_attention_backends(
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
_attention_backends["backend_selection_requires_update"] = False
return available_backends, fused_attention_backend
return available_backends, flash_attention_backend, fused_attention_backend
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
with logging_context():
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, fused_attention_backend = test()
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, fused_attn_backends
return available_backends, flash_attention_backend, fused_attn_backends
model_configs_base = {
......@@ -267,7 +279,7 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
......@@ -1131,7 +1143,7 @@ def test_transformer_layer(
workspace_opt = True
# Test backend availability
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from collections import OrderedDict
from typing import List
import os
import logging
import math
import pytest
import torch
from torch.distributions import Exponential
from transformer_engine.pytorch import make_graphed_callables
from transformer_engine.common import recipe
from transformer_engine.pytorch import fp8_autocast, fp8_model_init
from transformer_engine.pytorch.transformer import (
TransformerLayer,
)
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
)
from test_fused_attn import (
ModelConfig,
reset_rng_states,
_get_attention_backends,
)
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
param_types = [torch.float16]
if is_bf16_compatible():
param_types.append(torch.bfloat16)
model_configs_infer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"infer_0": ModelConfig(
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
}
qkv_formats = ["bshd", "sbhd", "thd"]
def to_pretty_string(x: torch.Tensor):
return "[" + ",".join(["{:>3s}".format(str(i)) for i in x.tolist()]) + "]"
def round_up(a: int, b: int):
return b * math.ceil(a / b)
class Simulation:
def __init__(
self,
total_requests: int = 10,
max_seq_len: int = 1024,
max_ctx_len: int = 128,
max_batch_size: int = 5,
poisson_rate: float = 1,
):
self.total_requests = total_requests
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
self.poisson_rate = poisson_rate
# calculate maximum context/generation length
self.max_ctx_len = max_ctx_len
self.max_gen_len = max_seq_len - self.max_ctx_len
# simulate sequence ids in monotonically increasing fashion
self.seq_ids = torch.range(0, total_requests - 1, dtype=torch.int32, device="cpu")
# simulate context lengths in Uniform distribution
self.context_lens = torch.randint(
1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu"
)
# simulate gen lengths in Exponential distribution
gen_dist = Exponential(1 / self.max_gen_len)
gen_lens = gen_dist.sample((total_requests,))
gen_lens = torch.where(gen_lens > self.max_gen_len, self.max_gen_len, gen_lens).to(
dtype=torch.int32, device="cpu"
)
self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to(dtype=torch.int32, device="cpu")
# simulate arrival times in Poisson distribution
if poisson_rate is None:
self.poisson_rate = torch.randint(1, max_batch_size, [1]).item()
interval_dist = Exponential(self.poisson_rate)
arrival_intervals = interval_dist.sample((total_requests,))
self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to(
dtype=torch.int32, device="cpu"
)
self.last_arrival = self.arrival_times.max().item()
# initialize tensors
self.reset()
def reset(self):
self.t = 0
self.request_delays = torch.zeros([self.total_requests], dtype=torch.int32, device="cpu")
self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32, device="cpu")
self.serving_times = self.arrival_times
self.complete_times = self.arrival_times
# batch info at step t
self.t_seq_ids = torch.Tensor([]).to(dtype=torch.bool, device="cpu")
self.t_ctx_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu")
self.t_gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu")
self.t_total_lens = self.t_ctx_lens + self.t_gen_lens
self.t_batch_size = 0
# step info from step t-1 to t
self.step_lens = torch.Tensor([]).to(dtype=torch.int32, device="cpu")
def print_setup(self, logger):
logger.info("Simulation:")
logger.info(" {:<31s}: {}".format("total number of requests", self.total_requests))
logger.info(" {:<31s}: {}".format("max sequence length per request", self.max_seq_len))
logger.info(" {:<31s}: {}".format("max context length", self.max_ctx_len))
logger.info(" {:<31s}: {}".format("max generation length", self.max_gen_len))
logger.info(" {:<31s}: {}".format("max batch size per iteration", self.max_batch_size))
logger.info(" {:<31s}: {}".format("Poisson rate", self.poisson_rate))
logger.info(" {:<17s}: {}".format("sequence ids", to_pretty_string(self.seq_ids)))
logger.info(" {:<17s}: {}".format("arrival times", to_pretty_string(self.arrival_times)))
logger.info(" {:<17s}: {}".format("context lengths", to_pretty_string(self.context_lens)))
logger.info(" {:<17s}: {}".format("generation lengths", to_pretty_string(self.gen_lens)))
def print_step(self, logger):
logger.info(f"Step t = {self.t}:")
logger.info(" {:<15s}: {}".format("t_batch_size", self.t_batch_size))
logger.info(" {:<15s}: {}".format("t_seq_ids", self.t_seq_ids.tolist()))
logger.info(" {:<15s}: {}".format("t_ctx_lens", self.t_ctx_lens.tolist()))
logger.info(" {:<15s}: {}".format("t_gen_lens", self.t_gen_lens.tolist()))
logger.info(" {:<15s}: {}".format("t_total_lens", self.t_total_lens.tolist()))
logger.info(" {:<15s}: {}".format("step_lens", self.step_lens.tolist()))
def print_summary(self, logger):
logger.info("Summary:")
logger.info(" {:<18s}: {}".format("total steps taken", self.t))
logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times)))
logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times)))
logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens)))
logger.info(" {:<18s}: {}".format("complete_times", to_pretty_string(self.complete_times)))
def add_new_seqs(self, new_seq_ids):
# get ctx_lens for new seqs
self.t_seq_ids = torch.cat([self.t_seq_ids, new_seq_ids], dim=0)
self.t_ctx_lens = torch.cat([self.t_ctx_lens, self.context_lens[new_seq_ids]], dim=0)
gen_lens = torch.Tensor([0] * len(new_seq_ids)).to(dtype=torch.int32, device="cpu")
self.t_gen_lens = torch.cat([self.t_gen_lens, gen_lens], dim=0)
# append new seqs' ctx_lens to step_lens
self.step_lens = torch.cat([self.step_lens, self.context_lens[new_seq_ids]], dim=0)
def remove_finished(self):
# figure out which seqs have finished
finished = torch.where(self.t_gen_lens - self.gen_lens[self.t_seq_ids] < 0, False, True).to(
dtype=torch.bool, device="cpu"
)
self.t_seq_ids = self.t_seq_ids[~finished]
self.t_ctx_lens = self.t_ctx_lens[~finished]
self.t_gen_lens = self.t_gen_lens[~finished]
# add ones for unfinished seqs to step_lens
self.step_lens = torch.ones([len(self.t_seq_ids)], dtype=torch.int32, device="cpu")
def step(self, dynamic_fill: bool = True):
# remove finished seqs
if self.t != 0:
self.remove_finished()
# get allowed new seqs
arrived_seq_ids = torch.where(self.arrival_times == self.t, True, False).nonzero().view(-1)
queuing_seq_ids = torch.cat([self.delayed_seq_ids, arrived_seq_ids], dim=0)
if dynamic_fill:
allowed_num_new_seqs = self.max_batch_size - len(self.t_seq_ids)
else:
allowed_num_new_seqs = 0 if len(self.t_seq_ids) else self.max_batch_size
if len(queuing_seq_ids) > allowed_num_new_seqs:
new_seq_ids = queuing_seq_ids[:allowed_num_new_seqs]
self.delayed_seq_ids = queuing_seq_ids[allowed_num_new_seqs:]
self.request_delays[self.delayed_seq_ids.tolist()] += 1
else:
new_seq_ids = queuing_seq_ids
self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32)
# add new seqs to batch
self.add_new_seqs(new_seq_ids)
# update batch variables
self.t_batch_size = len(self.t_seq_ids)
self.t_total_lens = self.t_ctx_lens + self.t_gen_lens
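# Illustrative only (not part of the test): a bare driver loop showing how the
# Simulation helper above is stepped, mirroring the loop in test_paged_attn below.
# sim.t_seq_ids / sim.step_lens describe which sequences are active at step t and
# how many new tokens each contributes (ctx_len on arrival, 1 per step afterwards).
#
# sim = Simulation(total_requests=8, max_seq_len=64, max_ctx_len=16,
#                  max_batch_size=4, poisson_rate=2)
# while True:
#     sim.step(dynamic_fill=True)
#     if sim.t_batch_size == 0:
#         if sim.t > sim.last_arrival:
#             break  # every request has arrived and finished generating
#         sim.t += 1
#         continue
#     # ... run one model step for sim.t_seq_ids with sim.step_lens new tokens ...
#     sim.t += 1
#     sim.t_gen_lens = sim.t_gen_lens + 1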
def get_model(
module: torch.nn.Module,
config: ModelConfig,
dtype: torch.dtype,
backend: str = "FusedAttention",
qkv_format: str = "bshd",
num_layers: int = 1,
mode: str = "reference",
is_fp8: bool = False,
):
reset_rng_states()
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, num_layers)
if mode == "reference":
attn_mask_type = "causal"
qkv_format = "bshd"
if mode == "inference":
attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding"
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=is_fp8,
fp8_mha=False,
)
if module == "TransformerLayer":
hidden_size = config.head_dim_qk * config.num_heads
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=4 * hidden_size,
num_attention_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
hidden_dropout=0.0,
attention_dropout=config.dropout_p,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=config.head_dim_qk,
self_attn_mask_type=attn_mask_type,
fuse_qkv_params=False,
params_dtype=dtype,
attn_input_format=qkv_format,
)
.cuda()
.eval()
for layer_number in range(1, num_layers + 1)
]
if module == "DotProductAttention":
with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe):
model = [
DotProductAttention(
kv_channels=config.head_dim_qk,
num_attention_heads=config.num_heads,
num_gqa_groups=config.num_gqa_groups,
layer_number=layer_number,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
)
.cuda()
.eval()
for layer_number in range(1, num_layers + 1)
]
return model
def generate_args(
module: torch.nn.Module,
config: ModelConfig,
dtype: torch.dtype,
qkv_format: str = "bshd",
mode: str = "full_inputs",
):
# full inputs used as reference
if mode == "full_inputs":
warmup = False
shapes = []
if module == "TransformerLayer":
shapes.append(
[config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk]
)
if module == "DotProductAttention":
shapes.append(
[config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk]
)
shapes.append(
[
config.total_requests,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_qk,
]
)
shapes.append(
[
config.total_requests,
config.max_seqlen_kv,
config.num_gqa_groups,
config.head_dim_v,
]
)
# sample args used for cuda graph warmup
elif mode == "sample_args":
warmup = True
shapes = []
if qkv_format == "bshd":
shape = [config.batch_size, config.max_ctx_len]
if qkv_format == "sbhd":
shape = [config.max_ctx_len, config.batch_size]
if qkv_format == "thd":
shape = [config.batch_size * config.max_ctx_len]
if module == "TransformerLayer":
shapes.append([*shape, config.num_heads * config.head_dim_qk])
if module == "DotProductAttention":
shapes.append([*shape, config.num_heads, config.head_dim_qk])
shapes.append([*shape, config.num_gqa_groups, config.head_dim_qk])
shapes.append([*shape, config.num_gqa_groups, config.head_dim_v])
num_tensors = len(shapes)
if warmup:
return [
torch.ones(
*shapes[i],
device="cuda",
dtype=dtype,
)
for i in range(num_tensors)
]
elif module == "TransformerLayer":
return [
0.01
* torch.randint(
-100,
100,
shapes[i],
device="cuda",
dtype=dtype,
)
for i in range(num_tensors)
]
elif module == "DotProductAttention":
return [
0.1
* torch.randn(
*shapes[i],
device="cuda",
dtype=dtype,
)
for i in range(num_tensors)
]
def get_tols(module, backend, dtype):
if module == "TransformerLayer":
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
if module == "DotProductAttention":
tols = {
torch.half: (1e-3, 1e-3),
torch.bfloat16: (1e-2, 1e-3),
torch.float8_e4m3fn: (2e-2, 3e-2),
}
return tols[dtype]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model", model_configs_infer.keys())
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("is_paged", [False, True])
@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"])
@pytest.mark.parametrize("is_cuda_graph", [False, True])
@pytest.mark.parametrize("is_fp8", [False, True])
def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
reset_rng_states()
logger = logging.getLogger("test_paged_attn")
fp8_recipe = recipe.DelayedScaling(
margin=0,
fp8_format=recipe.Format.HYBRID,
amax_history_len=1,
amax_compute_algo="most_recent",
fp8_dpa=is_fp8,
fp8_mha=False,
)
fp8_meta = {}
fp8_meta["recipe"] = fp8_recipe
config = model_configs_infer[model]
num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1
# flash-attn v2 requires page_size >= 256
if backend == "FlashAttention" and not fa_utils.v3_is_installed:
config_max_seqlen_q = config.max_seqlen_q
config_max_seqlen_kv = config.max_seqlen_kv
config.max_seqlen_q = 256
config.max_seqlen_kv = 256
# create a real-life simulation
max_batch_size = config.batch_size
page_size = None
total_num_pages = None
if is_paged:
page_size = 256 if backend == "FlashAttention" and not fa_utils.v3_is_installed else 1
config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size)
total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size)
else:
config.max_seqlen_kv = round_up(config.max_seqlen_kv, 64)
sim = Simulation(
total_requests=config.total_requests,
max_seq_len=config.max_seqlen_kv,
max_ctx_len=config.max_ctx_len,
max_batch_size=max_batch_size,
poisson_rate=2,
)
sim.print_setup(logger)
# initialize inference_params
inference_params = InferenceParams(
max_batch_size=max_batch_size,
max_seqlen_kv=config.max_seqlen_kv,
num_heads_kv=config.num_gqa_groups,
head_dim_k=config.head_dim_qk,
head_dim_v=config.head_dim_v,
dtype=dtype,
is_paged=is_paged,
page_size=page_size,
total_num_pages=total_num_pages,
max_ctx_len=config.max_ctx_len,
qkv_format=qkv_format,
)
if module == "DotProductAttention":
for layer_number in range(1, num_layers + 1):
inference_params.allocate_memory(layer_number)
# figure out supported backends
inference_params_qkv_format = "bshd"
qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2)
if is_paged:
qkv_layout = "paged_kv_" + qkv_layout
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=False,
is_training=False,
fp8=is_fp8,
fp8_meta=fp8_meta,
inference_params=inference_params,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if backend == "FlashAttention" and not flash_attn_supported:
pytest.skip("FlashAttention backend is not supported")
if backend == "FusedAttention" and not fused_attn_supported:
pytest.skip("FusedAttention backend is not supported")
if backend == "UnfusedAttention" and not unfused_attn_supported:
pytest.skip("UnfusedAttention backend is not supported")
os.environ["NVTE_FLASH_ATTN"] = str(int(backend == "FlashAttention"))
os.environ["NVTE_FUSED_ATTN"] = str(int(backend == "FusedAttention"))
os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention"))
if backend == "UnfusedAttention" and is_cuda_graph:
pytest.skip("CUDA graph is not supported for UnfusedAttention backend")
# TransformerLayer FP8 TN Gemm currently requires %8=0
if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention"):
pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported")
# create full model
logger.info("=== Generating all tokens at once ===")
model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="reference")
# generate data for all requests
full_inputs = generate_args(module, config, dtype, qkv_format="bshd", mode="full_inputs")
# generate reference results
if module == "DotProductAttention":
full_output = full_inputs
for m in model:
full_output = m(
*full_output if isinstance(full_output, List) else full_output,
)
if module == "TransformerLayer":
full_output = full_inputs
for m in model:
full_output = m(
full_output[0] if isinstance(full_output, List) else full_output,
)
# create inference model
logger.info("=== Generating one token at a time ===")
model = get_model(
module,
config,
dtype,
backend,
qkv_format,
num_layers,
mode="inference",
is_fp8=is_fp8,
)
# graph the model if necessary
if is_cuda_graph:
t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu")
step_lens = config.max_ctx_len * torch.ones(max_batch_size, dtype=torch.int32, device="cpu")
step_dict = OrderedDict(zip(t_seq_ids.tolist(), step_lens.tolist()))
inference_params.pre_step(step_dict)
sample_args = generate_args(
module, config, dtype, qkv_format=qkv_format, mode="sample_args"
)
sample_kwargs = {}
sample_kwargs["cu_seqlens_q"] = torch.linspace(
0,
config.batch_size * config.max_ctx_len,
steps=config.batch_size + 1,
device="cuda",
dtype=torch.int32,
)
sample_kwargs["cu_seqlens_kv"] = torch.linspace(
0,
config.batch_size * config.max_ctx_len,
steps=config.batch_size + 1,
device="cuda",
dtype=torch.int32,
)
sample_kwargs["inference_params"] = inference_params
sample_kwargs["max_seqlen_q"] = config.max_ctx_len
sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv
model = [
make_graphed_callables(
model[i],
sample_args,
num_warmup_iters=10,
fp8_enabled=is_fp8,
sample_kwargs=sample_kwargs,
fp8_recipe=fp8_recipe,
)
for i in range(num_layers)
]
sim.reset()
inference_params.reset()
step_dict = OrderedDict()
# simulate step by step
# t-1: ...
# compute for seq_ids = [0, 1, 2], ctx_lens = [5, 2, 3], gen_lens = [2, 9, 4],
# batch_size = 3, step_lens = [1, 1, 1]
# increase counter for gen_lens = [3, 10, 5]
# t: detect seq 1 is finished since expected_gen_lens = [12, 10, 15]
# add two new seqs 3 and 4, with ctx lens 10 and 11
# compute for seq_ids = [0, 2, 3, 4], ctx_lens = [5, 3, 10, 11], gen_lens = [3, 5, 0, 0],
# batch_size = 4, step_lens = [1, 1, 10, 11]
# increase counter for gen_lens = [3, 5, 1, 1]
max_tokens = config.batch_size * config.max_ctx_len
while True:
# prepare batch for the current step
dynamic_fill = True # inference_params.is_paged
sim.step(dynamic_fill=dynamic_fill)
sim.print_step(logger)
if sim.t_batch_size == 0:
# all sequences are finished
if sim.t > sim.last_arrival:
sim.serving_times = sim.arrival_times + sim.request_delays
sim.complete_times = sim.serving_times + sim.gen_lens
break
# not finished; run next iteration
else:
sim.t += 1
continue
# create incremental input
batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size
max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item()
num_tensors = len(full_inputs)
if qkv_format == "thd":
incremental_inputs = []
for i in range(num_tensors):
inp = full_inputs[i]
inc_inp = torch.Tensor().to(dtype=dtype, device="cuda")
for i, seq in enumerate(sim.t_seq_ids):
start = (sim.t_total_lens[i] - sim.step_lens[i]).item()
end = sim.t_total_lens[i].item()
inc_inp = torch.cat([inc_inp, inp[seq, start:end]], dim=0)
if is_cuda_graph:
inc_inp = torch.cat(
[
inc_inp,
torch.zeros(
max_tokens - sum(sim.step_lens),
*inp.shape[2:],
dtype=dtype,
device=inc_inp.device,
),
],
dim=0,
)
incremental_inputs.append(inc_inp)
else:
incremental_inputs = []
for i in range(num_tensors):
inp = full_inputs[i]
inc_inp = torch.zeros(
batch_size,
max_seqlen_q,
*inp.shape[2:],
dtype=dtype,
device="cuda",
)
for i, seq in enumerate(sim.t_seq_ids):
start = (sim.t_total_lens[i] - sim.step_lens[i]).item()
end = sim.t_total_lens[i].item()
inc_inp[i, : sim.step_lens[i], :] = inp[seq, start:end]
if qkv_format == "sbhd":
inc_inp = inc_inp.transpose(0, 1).contiguous()
incremental_inputs.append(inc_inp)
# run step
batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0)
cu_seqlens_kv = cu_seqlens_q.clone()
step_dict = OrderedDict(zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist()))
inference_params.pre_step(step_dict)
if inference_params.is_paged:
inference_params.cache_manager.print_cache()
incremental_output = incremental_inputs
with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe):
for m in model:
incremental_output = m(
*incremental_output,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
inference_params=inference_params,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
)
incremental_output = [incremental_output]
incremental_output = incremental_output[0]
# compare results
atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
for i, seq in enumerate(sim.t_seq_ids):
token_index = sim.step_lens[i] - 1
if qkv_format == "bshd":
torch.testing.assert_close(
full_output[seq, sim.t_total_lens[i] - 1, :],
incremental_output[i, sim.step_lens[i] - 1, :],
atol=atol,
rtol=rtol,
)
if qkv_format == "sbhd":
torch.testing.assert_close(
full_output[seq, sim.t_total_lens[i] - 1, :],
incremental_output[sim.step_lens[i] - 1, i, :],
atol=atol,
rtol=rtol,
)
if qkv_format == "thd":
torch.testing.assert_close(
full_output[seq, sim.t_total_lens[i] - 1, :],
incremental_output[cu_seqlens_q[i + 1] - 1, :],
atol=atol,
rtol=rtol,
)
sim.t += 1
sim.t_gen_lens = sim.t_gen_lens + 1
# last value in complete_times should be equal to sim.t
sim.serving_times = sim.arrival_times + sim.request_delays
sim.complete_times = sim.serving_times + sim.gen_lens
sim.print_summary(logger)
if backend == "FlashAttention" and not fa_utils.v3_is_installed:
config.max_seqlen_q = config_max_seqlen_q
config.max_seqlen_kv = config_max_seqlen_kv
......@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
from collections import OrderedDict
import math
import os
from typing import Dict, List, Optional
......@@ -59,6 +60,8 @@ torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
torch._dynamo.config.recompile_limit = 16
class ModelConfig:
def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
......@@ -77,9 +80,9 @@ model_configs = {
model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 256),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]
......@@ -2037,14 +2040,25 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
@pytest.mark.parametrize("is_paged", [False, True])
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
reset_rng_states()
if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
pytest.skip("FusedAttention and FlashAttention do not support FP32")
if use_RoPE:
pytest.skip("KV cache does not support starting positions for RoPE")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
elif backend == "UnfusedAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
config = model_configs_inference[model_key]
......@@ -2057,7 +2071,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
# Limits the max size of KV-cache
B_max = B
S_max = S + 2
S_max = S
if module == "TransformerLayer":
model = TransformerLayer(
......@@ -2087,7 +2101,17 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
.eval()
)
inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
inference_params = InferenceParams(
max_batch_size=B_max,
max_seqlen_kv=S_max,
num_heads_kv=H,
head_dim_k=head_size,
dtype=dtype,
is_paged=is_paged,
total_num_pages=int(B_max * S_max / 256),
page_size=256,
)
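# Sizing note (derived from the values above): with the 126m inference config
# (seq_len = 256, so S_max = 256) and page_size = 256, each sequence fits in a
# single page, so total_num_pages = B_max * S_max / 256 = B_max pages cover the batch.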
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")
input = torch.randn((S, B, D), dtype=dtype, device="cuda")
......@@ -2100,22 +2124,39 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementally generate outputs using KV-cache
step_dict = OrderedDict(zip(list(range(B)), [1] * B))
for i in range(S):
inference_params.pre_step(step_dict)
if input_format == "sbhd":
incremental_input = input[i].view(1, B, D)
else:
incremental_input = input[:, i, :].view(B, 1, D)
seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv = cu_seqlens_q.clone()
mask_type = "padding"
kwargs = {}
if module == "TransformerLayer":
kwargs["self_attn_mask_type"] = mask_type
else:
kwargs["attn_mask_type"] = mask_type
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None,
**kwargs,
max_seqlen_q=1,
max_seqlen_kv=S,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
inference_params.sequence_len_offset += 1
if input_format == "sbhd":
incremental_output[i] = line_output.view(B, D)
incremental_output[i, :, :] = line_output.view(B, D)
else:
incremental_output[:, i, :] = line_output.view(B, D)
......
......@@ -37,7 +37,18 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
......@@ -51,12 +62,14 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
return NVTE_QKV_Format::NVTE_SBHD;
case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_BSH3D:
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
return NVTE_QKV_Format::NVTE_BSHD;
case NVTE_QKV_Layout::NVTE_T3HD:
case NVTE_QKV_Layout::NVTE_TH3D:
......@@ -64,6 +77,56 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_THD_TH2D:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
return NVTE_QKV_Format::NVTE_THD;
case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
return NVTE_QKV_Format::NVTE_SBHD_2BSHD;
case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
return NVTE_QKV_Format::NVTE_BSHD_2SBHD;
case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
return NVTE_QKV_Format::NVTE_THD_2BSHD;
case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
return NVTE_QKV_Format::NVTE_THD_2SBHD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
}
// map NVTE_QKV_Layout to NVTE_QKV_Format for Q
NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) {
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
switch (qkv_format) {
case NVTE_QKV_Format::NVTE_SBHD:
case NVTE_QKV_Format::NVTE_SBHD_2BSHD:
return NVTE_QKV_Format::NVTE_SBHD;
case NVTE_QKV_Format::NVTE_BSHD:
case NVTE_QKV_Format::NVTE_BSHD_2SBHD:
return NVTE_QKV_Format::NVTE_BSHD;
case NVTE_QKV_Format::NVTE_THD:
case NVTE_QKV_Format::NVTE_THD_2BSHD:
case NVTE_QKV_Format::NVTE_THD_2SBHD:
return NVTE_QKV_Format::NVTE_THD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
}
// map NVTE_QKV_Layout to NVTE_QKV_Format for KV
NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
switch (qkv_format) {
case NVTE_QKV_Format::NVTE_SBHD:
case NVTE_QKV_Format::NVTE_BSHD_2SBHD:
case NVTE_QKV_Format::NVTE_THD_2SBHD:
return NVTE_QKV_Format::NVTE_SBHD;
case NVTE_QKV_Format::NVTE_BSHD:
case NVTE_QKV_Format::NVTE_SBHD_2BSHD:
case NVTE_QKV_Format::NVTE_THD_2BSHD:
return NVTE_QKV_Format::NVTE_BSHD;
case NVTE_QKV_Format::NVTE_THD:
return NVTE_QKV_Format::NVTE_THD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
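// Example (illustrative, not part of the diff): for a layout where Q is THD and the
// paged KV cache is BSHD, i.e. NVTE_Paged_KV_THD_BSHD_BSHD,
//   nvte_get_qkv_format() -> NVTE_THD_2BSHD
//   nvte_get_q_format()   -> NVTE_THD
//   nvte_get_kv_format()  -> NVTE_BSHD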
......@@ -81,6 +144,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto cudnn_runtime_version = cudnnGetVersion();
......@@ -202,11 +267,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
// 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right}
(cudnn_runtime_version >= 90500 &&
layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv)) &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
// 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv)
(cudnn_runtime_version >= 90600 &&
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)) &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
// 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right}
// for any q_format/kv_format, and paged/non-paged
(cudnn_runtime_version >= 90700 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
max_seqlen_q <= max_seqlen_kv)))) &&
// bias + mask combination
(!(cudnn_runtime_version >= 8906 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
......@@ -216,7 +301,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
(qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
cudnn_runtime_version >= 90600))) &&
cudnn_runtime_version >= 90600)) ||
((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD ||
(q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) ||
kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD ||
(kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) &&
cudnn_runtime_version >= 90700)) &&
// sliding window
// pre-9.2: full attn, causal
((cudnn_runtime_version < 90200 && window_size_left == -1 &&
......@@ -465,22 +555,23 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
}
}
// NVTE fused attention FWD with packed KV
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
......@@ -505,11 +596,40 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
}
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_KV->data.shape[0];
}
int64_t num_pages_k = 0;
int64_t num_pages_v = 0;
int64_t page_size_k = 0;
int64_t page_size_v = 0;
int64_t max_pages_per_seq_k = 0;
int64_t max_pages_per_seq_v = 0;
if (input_page_table_k->data.dptr != nullptr) {
max_pages_per_seq_k = input_page_table_k->data.shape[1];
}
if (input_page_table_v->data.dptr != nullptr) {
max_pages_per_seq_v = input_page_table_v->data.shape[1];
}
if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) {
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (kv_format == NVTE_QKV_Format::NVTE_BSHD) {
num_pages_k = input_KV->data.shape[0];
page_size_k = input_KV->data.shape[1];
// K and V share one packed tensor here, so V uses the same paging geometry as K
num_pages_v = num_pages_k;
page_size_v = page_size_k;
} else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) {
num_pages_k = input_KV->data.shape[1];
page_size_k = input_KV->data.shape[0];
num_pages_v = num_pages_k;
page_size_v = page_size_k;
}
}
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -531,11 +651,12 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -596,9 +717,12 @@ void nvte_fused_attn_bwd_kvpacked(
}
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_KV->data.shape[0];
}
......@@ -664,7 +788,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
......@@ -676,6 +801,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
......@@ -686,18 +813,49 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
auto ndim = input_Q->data.shape.size();
auto ndim_kv = input_K->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim_kv - 2];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim_kv - 1];
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_K->data.shape[0];
}
int64_t num_pages_k = 0;
int64_t num_pages_v = 0;
int64_t page_size_k = 0;
int64_t page_size_v = 0;
int64_t max_pages_per_seq_k = 0;
int64_t max_pages_per_seq_v = 0;
if (input_page_table_k->data.dptr != nullptr) {
max_pages_per_seq_k = input_page_table_k->data.shape[1];
}
if (input_page_table_v->data.dptr != nullptr) {
max_pages_per_seq_v = input_page_table_v->data.shape[1];
}
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) {
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (kv_format == NVTE_QKV_Format::NVTE_BSHD) {
num_pages_k = input_K->data.shape[0];
page_size_k = input_K->data.shape[1];
num_pages_v = input_V->data.shape[0];
page_size_v = input_V->data.shape[1];
} else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) {
num_pages_k = input_K->data.shape[1];
page_size_k = input_K->data.shape[0];
num_pages_v = input_V->data.shape[1];
page_size_v = input_V->data.shape[0];
}
}
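  // Illustrative example (values made up): with qkv_layout = NVTE_Paged_KV_BSHD_BSHD_BSHD, a K
  // cache of shape [num_pages, page_size, h_kv, d_qk] = [256, 32, 8, 128] gives num_pages_k = 256
  // and page_size_k = 32, and a page table of shape [batch, max_pages_per_seq_k] supplies
  // max_pages_per_seq_k from its second dimension; the SBHD branch simply swaps the first two dims.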
auto handle = cudnnExecutionPlanManager::Instance().GetHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
......@@ -719,11 +877,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale,
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
wkspace, stream, handle);
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
......@@ -773,16 +932,20 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
auto ndim = input_Q->data.shape.size();
auto ndim_kv = input_K->data.shape.size();
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
size_t h_q = input_Q->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim - 2];
size_t h_kv = input_K->data.shape[ndim_kv - 2];
size_t d_qk = input_Q->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim - 1];
size_t d_v = input_V->data.shape[ndim_kv - 1];
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_K->data.shape[0];
}
......
......@@ -50,14 +50,16 @@ namespace transformer_engine {
namespace fused_attn {
void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v,
int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
......@@ -66,26 +68,35 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv) {
is_causal = true;
is_bottom_right = false;
}
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv && !is_padding) {
is_causal = true;
is_bottom_right = false;
}
bool is_dropout = (is_training && dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion();
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
if (is_paged_kv) {
NVTE_CHECK(is_padding, "Paged attention requires padding mask!");
}
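  // A paged KV cache is ragged by construction (each sequence fills only part of its pages), so
  // the actual per-sequence lengths must come in through the padding-mask path (seq_q/seq_kv
  // below) for the gathered K/V to be masked correctly; hence the check above.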
// keep original batch size because cu_seqlens are created with [b+1] shape
int64_t actual_b = b;
if (is_ragged && cudnn_runtime_version >= 90600) {
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = max_t_q;
s_kv = max_t_kv;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
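  // Illustrative example (values made up): with max_b = 64 and max_t_q = max_t_kv = 16384, every
  // THD/THD call builds the graph with b = 64 and s_q = s_kv = 16384, so the cached plan can be
  // replayed (e.g. when capturing with CUDA graphs) regardless of the token count of a given step.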
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
......@@ -97,6 +108,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
s_kv,
d_qk,
d_v,
num_pages_k,
num_pages_v,
page_size_k,
page_size_v,
max_pages_per_seq_k,
max_pages_per_seq_v,
bias_b,
bias_h,
scaling_factor,
......@@ -123,6 +140,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // page_table_k
std::shared_ptr<fe::graph::Tensor_attributes>, // page_table_v
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_q
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_k
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_v
......@@ -151,6 +170,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> page_table_k, page_table_v;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
offset_stats;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
......@@ -160,17 +180,36 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::vector<int64_t> v_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_Q_Matrix);
if (is_paged_kv) {
      generateMatrixStrides(num_pages_k, hg, page_size_k, page_size_k, d_qk, k_stride.data(),
                            layout, NVTE_QKV_Matrix::NVTE_K_Matrix);
      generateMatrixStrides(num_pages_v, hg, page_size_v, page_size_v, d_v, v_stride.data(),
                            layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
} else {
generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_K_Matrix);
generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_V_Matrix);
}
if (is_ragged) {
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d_qk})
.set_stride(q_stride));
if (is_ragged_q) {
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
Q->set_ragged_offset(offset_q);
}
K = mha_graph->tensor(fe::graph::Tensor_attributes().set_name("K").set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes().set_name("V").set_stride(v_stride));
if (is_paged_kv) {
K->set_dim({num_pages_k, hg, page_size_k, d_qk});
V->set_dim({num_pages_v, hg, page_size_v, d_v});
} else if (is_ragged_kv) {
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b + 1, 1, 1, 1})
......@@ -181,34 +220,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d_qk})
.set_stride(q_stride)
.set_ragged_offset(offset_q));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride)
.set_ragged_offset(offset_k));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride)
.set_ragged_offset(offset_v));
K->set_dim({b, hg, s_kv, d_qk}).set_ragged_offset(offset_k);
V->set_dim({b, hg, s_kv, d_v}).set_ragged_offset(offset_v);
} else {
Q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d_qk})
.set_stride(q_stride));
K = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride));
V = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride));
K->set_dim({b, hg, s_kv, d_qk});
V->set_dim({b, hg, s_kv, d_v});
}
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
......@@ -254,6 +270,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv);
}
if (is_paged_kv) {
page_table_k =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("page_table_k")
.set_dim({b, 1, max_pages_per_seq_k, 1})
                                         .set_stride({{max_pages_per_seq_k, max_pages_per_seq_k, 1, 1}})
.set_data_type(fe::DataType_t::INT32));
page_table_v =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("page_table_v")
.set_dim({b, 1, max_pages_per_seq_v, 1})
.set_stride({{max_pages_per_seq_v, max_pages_per_seq_v, 1, 1}})
.set_data_type(fe::DataType_t::INT32));
sdpa_options.set_paged_attention_k_table(page_table_k);
sdpa_options.set_paged_attention_v_table(page_table_v);
sdpa_options.set_paged_attention_max_seq_len_kv(static_cast<int32_t>(s_kv));
}
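    // Illustrative note: each page table holds per-sequence physical page indices, so logical KV
    // token s of batch entry i is expected to come from container page
    // page_table_k[i][s / page_size_k], row s % page_size_k (and likewise for V); s_kv above is
    // the maximum logical KV length seen by the kernel.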
if (is_dropout) {
dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
......@@ -273,37 +307,27 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
if (is_ragged) {
O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride);
if (is_ragged_q) {
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
O->set_output(true)
.set_dim({b, h, s_q, d_v})
.set_stride(o_stride)
.set_ragged_offset(offset_o);
} else {
O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride);
O->set_ragged_offset(offset_o);
}
if (is_ragged && cudnn_runtime_version >= 90600) {
Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
if (is_ragged_q && cudnn_runtime_version >= 90600) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, 1, h, 1})
.set_ragged_offset(offset_stats);
Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Stats->set_output(true)
.set_data_type(fe::DataType_t::FLOAT)
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1});
Stats->set_stride({h * s_q, s_q, 1, 1});
}
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
......@@ -316,9 +340,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o)
: std::make_tuple(nullptr, nullptr, nullptr, nullptr);
auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600)
auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v)
: std::make_tuple(nullptr, nullptr);
auto offset_qo_tuple =
is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr);
auto offset_kv_tuple =
is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr);
auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
......@@ -330,16 +358,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
auto return_tuple =
std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple,
padding_tuple, offset_qkvo_tuple, offset_s_tuple, dropout_tuple);
auto return_tuple = std::tuple_cat(
std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple,
page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, offset_q, offset_k,
offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] =
auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v,
offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_fprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
......@@ -351,11 +379,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged) {
if (cudnn_runtime_version >= 90600) {
seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset;
seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset;
}
}
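  // For example, is_ragged_q && is_ragged_kv with cuDNN >= 9.6 reserves 5 slots (Q, O, K, V and
  // softmax stats offsets), while a ragged KV with a non-ragged Q reserves 2 slots (K and V).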
if (workspace == nullptr) {
......@@ -391,28 +420,49 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
variant_pack[seq_kv] = devActualSeqlenKV;
}
if (is_ragged) {
if (is_paged_kv) {
variant_pack[page_table_k] = devPtrPageTableK;
variant_pack[page_table_v] = devPtrPageTableV;
}
if (is_ragged_q || is_ragged_kv) {
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
void *devOffsetsQ =
void *devOffsets =
static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
void *devOffsetsQ = nullptr;
void *devOffsetsO = nullptr;
if (is_ragged_q) {
devOffsetsQ = devOffsets;
devOffsetsO = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
}
void *devOffsetsK = nullptr;
void *devOffsetsV = nullptr;
if (is_ragged_kv) {
devOffsetsK = static_cast<int8_t *>(devOffsets) +
static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
}
void *devOffsetsS = nullptr;
if (cudnn_runtime_version >= 90600) {
devOffsetsS = static_cast<int8_t *>(devOffsetsO) + num_bytes_per_ragged_offset;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
devOffsetsS = static_cast<int8_t *>(devOffsets) +
(static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
num_bytes_per_ragged_offset;
}
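    // Layout of the devOffsets scratch (one num_bytes_per_ragged_offset slot each):
    // [offsets_q][offsets_o] when is_ragged_q, then [offsets_k][offsets_v] when is_ragged_kv,
    // then [offsets_stats] when is_ragged_q and cuDNN >= 9.6.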
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO, devOffsetsS);
if (is_ragged_q) {
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_o] = devOffsetsO;
}
if (is_ragged_kv) {
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
variant_pack[offset_o] = devOffsetsO;
if (cudnn_runtime_version >= 90600) {
}
if (is_ragged_q && cudnn_runtime_version >= 90600) {
variant_pack[offset_stats] = devOffsetsS;
}
}
......@@ -447,25 +497,37 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv) {
is_causal = true;
is_bottom_right = false;
}
bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) ||
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK));
if (is_bottom_right && s_q == s_kv && !is_padding) {
is_causal = true;
is_bottom_right = false;
}
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD);
if (is_paged_kv) {
NVTE_CHECK(is_padding, "Paged attention requires padding mask!");
}
// keep original batch size because cu_seqlens are created with [b+1] shape
int64_t actual_b = b;
if (is_ragged && cudnn_runtime_version >= 90600) {
if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
// replace batch size and maximum sequence lengths with maximum token counts
// for query and key/value so the graph is static within each quantization bucket
b = max_b;
s_q = max_t_q;
s_kv = max_t_kv;
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
// We choose between 32-bit and 64-bit offsets depending on need.
......@@ -480,6 +542,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
s_kv,
d_qk,
d_v,
0,
0,
0,
0,
0,
0,
bias_b,
bias_h,
scaling_factor,
......@@ -556,53 +624,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
if (is_ragged) {
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_v")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d_qk})
.set_stride(q_stride)
.set_ragged_offset(offset_q));
k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("K")
.set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride)
.set_ragged_offset(offset_k));
v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("V")
.set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride)
.set_ragged_offset(offset_v));
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d_v})
.set_stride(o_stride)
.set_ragged_offset(offset_o));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d_v})
.set_stride(o_stride)
.set_ragged_offset(offset_o));
} else {
q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim({b, h, s_q, d_qk})
......@@ -623,26 +644,50 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_name("dO")
.set_dim({b, h, s_q, d_v})
.set_stride(o_stride));
if (is_ragged_q) {
offset_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_q")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
q->set_ragged_offset(offset_q);
o->set_ragged_offset(offset_o);
dO->set_ragged_offset(offset_o);
}
if (is_ragged && cudnn_runtime_version >= 90600) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
if (is_ragged_kv) {
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_v")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
k->set_ragged_offset(offset_k);
v->set_ragged_offset(offset_v);
}
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, 1, h, 1})
.set_data_type(fe::DataType_t::FLOAT)
.set_ragged_offset(offset_stats));
} else {
stats = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("stats")
.set_dim({b, h, s_q, 1})
.set_stride({h * s_q, s_q, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_stats")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
stats->set_stride({h * s_q, s_q, 1, 1});
}
attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes()
......@@ -659,8 +704,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
if (is_ragged && cudnn_runtime_version >= 90600) {
if (is_ragged_q && cudnn_runtime_version >= 90600) {
sdpa_backward_options.set_max_total_seq_len_q(s_q);
}
if (is_ragged_kv && cudnn_runtime_version >= 90600) {
sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
}
......@@ -724,23 +771,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options);
if (is_ragged) {
dQ->set_output(true)
.set_dim({b, h, s_q, d_qk})
.set_stride(q_stride)
.set_ragged_offset(offset_q);
dK->set_output(true)
.set_dim({b, hg, s_kv, d_qk})
.set_stride(k_stride)
.set_ragged_offset(offset_k);
dV->set_output(true)
.set_dim({b, hg, s_kv, d_v})
.set_stride(v_stride)
.set_ragged_offset(offset_v);
} else {
dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride);
dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride);
dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride);
if (is_ragged_q) {
dQ->set_ragged_offset(offset_q);
}
if (is_ragged_kv) {
dK->set_ragged_offset(offset_k);
dV->set_ragged_offset(offset_v);
}
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
......@@ -757,9 +796,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o)
: std::make_tuple(nullptr, nullptr, nullptr, nullptr);
auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600)
auto offset_qo_tuple =
is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr);
auto offset_kv_tuple =
is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr);
auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600)
? std::make_tuple(offset_stats)
: std::make_tuple(nullptr);
auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset)
......@@ -773,14 +814,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
auto return_tuple =
std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple,
offset_qkvo_tuple, offset_s_tuple, dropout_tuple);
offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv,
offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] =
offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_bprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
......@@ -792,11 +833,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
size_t seqlen_offsets_workspace_size = 0;
if (is_ragged) {
if (cudnn_runtime_version >= 90600) {
seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset;
if (is_ragged_q || is_ragged_kv) {
size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset;
} else {
seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset;
seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset;
}
}
if (workspace == nullptr) {
......@@ -845,28 +887,44 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
variant_pack[seq_kv] = devActualSeqlenKV;
}
if (is_ragged) {
if (is_ragged_q || is_ragged_kv) {
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
void *devOffsetsQ =
void *devOffsets =
static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
void *devOffsetsQ = nullptr;
void *devOffsetsO = nullptr;
if (is_ragged_q) {
devOffsetsQ = devOffsets;
devOffsetsO = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
}
void *devOffsetsK = nullptr;
void *devOffsetsV = nullptr;
if (is_ragged_kv) {
devOffsetsK = static_cast<int8_t *>(devOffsets) +
static_cast<int>(is_ragged_q) * 2 * num_bytes_per_ragged_offset;
devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
}
void *devOffsetsS = nullptr;
if (cudnn_runtime_version >= 90600) {
devOffsetsS = static_cast<int8_t *>(devOffsetsO) + num_bytes_per_ragged_offset;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
devOffsetsS = static_cast<int8_t *>(devOffsets) +
(static_cast<int>(is_ragged_q) + static_cast<int>(is_ragged_kv)) * 2 *
num_bytes_per_ragged_offset;
}
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO, devOffsetsS);
if (is_ragged_q) {
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_o] = devOffsetsO;
}
if (is_ragged_kv) {
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
variant_pack[offset_o] = devOffsetsO;
if (cudnn_runtime_version >= 90600) {
}
if (is_ragged_q && cudnn_runtime_version >= 90600) {
variant_pack[offset_stats] = devOffsetsS;
}
}
......@@ -987,11 +1045,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, bias_b, bias_h, is_training, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK,
devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets,
devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1095,20 +1154,23 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
void *devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
......@@ -1134,13 +1196,19 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrPageTableK = page_table_k->data.dptr;
void *devPtrPageTableV = page_table_v->data.dptr;
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
}
if (q_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_q = get_max_tokens(num_tokens_q);
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
......@@ -1150,7 +1218,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
......@@ -1168,7 +1236,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
......@@ -1203,11 +1271,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ,
devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1266,10 +1336,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
}
if (q_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_q = get_max_tokens(num_tokens_q);
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
......@@ -1319,17 +1394,20 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
void *devPtrQ = input_Q->data.dptr;
void *devPtrK = input_K->data.dptr;
void *devPtrV = input_V->data.dptr;
......@@ -1348,13 +1426,19 @@ void fused_attn_arbitrary_seqlen_fwd(
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr;
void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr;
void *devPtrPageTableK = page_table_k->data.dptr;
void *devPtrPageTableV = page_table_v->data.dptr;
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
}
if (q_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_q = get_max_tokens(num_tokens_q);
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
......@@ -1364,7 +1448,7 @@ void fused_attn_arbitrary_seqlen_fwd(
Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
......@@ -1382,7 +1466,7 @@ void fused_attn_arbitrary_seqlen_fwd(
Aux_CTX_Tensors->size = 2;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
} else {
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
......@@ -1417,11 +1501,13 @@ void fused_attn_arbitrary_seqlen_fwd(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale,
p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ,
devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1470,10 +1556,15 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t max_batch_size = 0;
size_t max_tokens_q = 0;
size_t max_tokens_kv = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) {
max_batch_size = get_max_batch_size(batch);
}
if (q_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_q = get_max_tokens(num_tokens_q);
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
......
......@@ -38,13 +38,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......@@ -61,13 +63,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void fused_attn_arbitrary_seqlen_fwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
......
......@@ -1679,6 +1679,12 @@ void fused_attn_fp8_fwd_impl_v1(
s_kv,
d,
d,
0,
0,
0,
0,
0,
0,
bias_b,
bias_h,
scaling_factor,
......@@ -1977,6 +1983,12 @@ void fused_attn_fp8_bwd_impl_v1(
s_kv,
d,
d,
0,
0,
0,
0,
0,
0,
bias_b,
bias_h,
scaling_factor,
......
......@@ -117,6 +117,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) ||
......@@ -223,6 +224,9 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
break;
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
......@@ -243,6 +247,52 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
strideA[hidden_transpose_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = s_kv * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = s_kv * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = b * h * d;
strideA[hidden_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) {
strideA[batch_dim_idx] = h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_transpose_dim_idx] = b * h * d;
strideA[hidden_transpose_dim_idx] = 1;
} else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) {
strideA[batch_dim_idx] = s_q * h * d;
strideA[head_dim_idx] = d;
strideA[seqlen_dim_idx] = h * d;
strideA[hidden_dim_idx] = 1;
}
break;
}
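  // Illustrative example: for NVTE_SBHD_BSHD_BSHD the KV tensors are BSHD, so the K matrix gets
  // strides {s_kv * h * d, d, h * d, 1} over [b, h, s, d], while Q/O keep the SBHD pattern with
  // the sequence dimension striding over b * h * d.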
if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) {
......@@ -379,30 +429,46 @@ __device__ void cu_seqlens_padded_to_offsets_impl(
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
auto cu_seqlens_id = min(tid, actual_b);
if (tid <= max_b) {
offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id];
if (offsets_s != nullptr) {
offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id];
}
if (offsets_q != nullptr && offsets_o != nullptr) {
offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id];
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
offsets_k[tid] = offsets_q[cu_seqlens_id];
offsets_v[tid] = offsets_q[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
break;
}
}
if (offsets_k != nullptr && offsets_v != nullptr) {
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_k[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
offsets_v[tid] = offsets_k[cu_seqlens_id];
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
offsets_v[tid] = offsets_k[cu_seqlens_id];
break;
}
}
}
}
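// Illustrative example of the offsets computed above for the separate-tensor (HD_HD_HD) layouts:
// with h = 16, hg = 8 and d_qk = d_v = 128, offsets_q[i] = 16 * 128 * cu_seqlens_q_padded[i] is
// the element offset of sequence i's first Q token, and offsets_k[i]/offsets_v[i] apply hg and
// the padded KV cumulative lengths in the same way.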
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b,
......@@ -433,6 +499,7 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
std::array<int64_t, 4> offsets_qkvo{};
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD:
offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q;
offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv;
offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv;
......
......@@ -93,6 +93,12 @@ struct FADescriptor_v1 {
std::int64_t s_kv;
std::int64_t d_qk;
std::int64_t d_v;
std::int64_t num_pages_k;
std::int64_t num_pages_v;
std::int64_t page_size_k;
std::int64_t page_size_v;
std::int64_t max_pages_per_seq_k;
std::int64_t max_pages_per_seq_v;
std::int64_t bias_b;
std::int64_t bias_h;
float attnScale;
......@@ -108,13 +114,16 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t bwd_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining,
dropoutProbability, layout, mask_type, window_size_left, window_size_right,
deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b,
rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic,
rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left,
window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left,
rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type,
rhs.bwd_tensor_type);
}
};
......
......@@ -25,7 +25,7 @@ extern "C" {
* head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`.
* `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length
* or padded to the same length, and `THD`-based layouts are used when sequences have
* different lengths in a batch.
* different lengths in a batch. `Paged_KV`-based layouts are used for paged attention.
*/
enum NVTE_QKV_Layout {
NVTE_SB3HD = 0, /*!< SB3HD layout */
......@@ -43,6 +43,16 @@ enum NVTE_QKV_Layout {
NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */
NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */
NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */
NVTE_SBHD_BSHD_BSHD = 15, /*!< SBHD_BSHD_BSHD layout */
NVTE_BSHD_SBHD_SBHD = 16, /*!< BSHD_SBHD_SBHD layout */
NVTE_THD_BSHD_BSHD = 17, /*!< THD_BSHD_BSHD layout */
NVTE_THD_SBHD_SBHD = 18, /*!< THD_SBHD_SBHD layout */
NVTE_Paged_KV_BSHD_BSHD_BSHD = 19, /*!< Paged_KV_BSHD_BSHD_BSHD layout */
NVTE_Paged_KV_BSHD_SBHD_SBHD = 20, /*!< Paged_KV_BSHD_SBHD_SBHD layout */
NVTE_Paged_KV_SBHD_BSHD_BSHD = 21, /*!< Paged_KV_SBHD_BSHD_BSHD layout */
NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */
NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */
NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */
};
/*! \enum NVTE_QKV_Layout_Group
......@@ -59,18 +69,28 @@ enum NVTE_QKV_Layout_Group {
NVTE_HD_H2D = 3,
/*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */
NVTE_HD_HD_HD = 4,
/*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */
NVTE_Paged_KV_HD_HD_HD = 5,
};
/*! \enum NVTE_QKV_Format
* \brief QKV formats
*/
enum NVTE_QKV_Format {
/*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD */
/*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD, Paged_KV_SBHD_SBHD_SBHD */
NVTE_SBHD = 0,
/*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD */
/*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD, Paged_KV_BSHD_BSHD_BSHD */
NVTE_BSHD = 1,
/*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */
NVTE_THD = 2,
/*! BSHD format for Q and SBHD format for KV, i.e. BSHD_SBHD_SBHD, Paged_KV_BSHD_SBHD_SBHD */
NVTE_BSHD_2SBHD = 3,
/*! SBHD format for Q and BSHD format for KV, i.e. SBHD_BSHD_BSHD, Paged_KV_SBHD_BSHD_BSHD */
NVTE_SBHD_2BSHD = 4,
/*! THD format for Q and BSHD format for KV, i.e. THD_BSHD_BSHD, Paged_KV_THD_BSHD_BSHD */
NVTE_THD_2BSHD = 5,
/*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */
NVTE_THD_2SBHD = 6,
};
/*! \enum NVTE_Bias_Type
......@@ -135,6 +155,22 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout);
*/
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get Q format for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
*
* \return q format, e.g. sbhd.
*/
NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get KV format for a given QKV layout.
*
* \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd.
*
* \return kv format, e.g. bshd.
*/
NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
......@@ -312,6 +348,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k].
* \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v].
* \param[in] rng_state Seed and offset of CUDA random number generator.
 * \param[in] max_seqlen_q Max sequence length for Q used in the computation;
 * it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
......@@ -329,16 +367,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
......@@ -445,6 +481,8 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1].
* \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k].
* \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v].
* \param[in] rng_state Seed and offset of CUDA random number generator.
 * \param[in] max_seqlen_q Max sequence length for Q used in the computation;
 * it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
......@@ -465,7 +503,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
......
......@@ -36,6 +36,14 @@
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
.value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) \
.value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \
.value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \
.value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \
.value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \
pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local()) \
.value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \
.value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \
......@@ -51,7 +59,17 @@
.value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \
.value("NVTE_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD) \
.value("NVTE_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD) \
.value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD) \
.value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD) \
.value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD) \
.value("NVTE_Paged_KV_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD) \
.value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \
.value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \
.value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \
.value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \
.value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
......
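With these bindings, the new formats and layouts become visible from Python. A hedged sketch of inspecting them, assuming the macro above is compiled into the PyTorch extension module (imported as `tex` elsewhere in this change; the exact module name is an assumption):

```python
# The enum members mirror the C header; e.g. NVTE_Paged_KV_THD_BSHD_BSHD is value 23.
import transformer_engine_torch as tex

layout = tex.NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD
fmt = tex.NVTE_QKV_Format.NVTE_THD_2BSHD  # THD queries over a BSHD KV cache
print(layout, fmt)
```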
......@@ -129,6 +129,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto o_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors);
......@@ -164,15 +165,16 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(),
nullptr);
ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr);
......@@ -256,6 +258,7 @@ static void FusedAttnForwardImpl(
backend, softmax_aux);
/* Call the underlying NVTE API */
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
......@@ -273,9 +276,10 @@ static void FusedAttnForwardImpl(
nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen,
is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, head_dim};
......@@ -283,13 +287,13 @@ static void FusedAttnForwardImpl(
auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
......
......@@ -68,6 +68,7 @@ from transformer_engine.pytorch.distributed import (
)
from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
prepare_for_saving,
......@@ -76,7 +77,6 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
# Import attention utils
import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils
from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log
from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb
......@@ -85,7 +85,7 @@ from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_p
# Setup Attention Logging
attn_log.setup_logging()
# Global vars for flash attn imports
# Global vars for flash attn v2 and v3 imports
flash_attn_cuda_bwd = None
flash_attn_func = None
flash_attn_varlen_func = None
......@@ -96,15 +96,7 @@ _flash_attn_varlen_bwd = None
try:
fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
except PackageNotFoundError:
if (
torch.cuda.is_available()
and get_device_compute_capability() >= (8, 0)
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.debug(
"flash-attn v2 is not installed. To use, please install it by"
""" "pip3 install flash-attn".""",
)
pass # only print warning if use_flash_attention_2 = True in get_attention_backend
else:
if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
if fa_utils.version_required_blackwell <= fa_utils.version <= fa_utils.max_version:
......@@ -143,35 +135,20 @@ else:
),
fa_utils.version,
)
# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flash-attn-3 3.0.0 released as flash-attn 3.0.0
try:
fa_utils.fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
except PackageNotFoundError:
if (
torch.cuda.is_available()
and get_device_compute_capability() >= (9, 0)
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.debug(
"flash-attn v3 is not installed. To use, please install it by \n%s",
fa_utils.v3_installation_steps,
)
pass # only print warning if use_flash_attention_3 = True in get_attention_backend
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import (
from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_3.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
)
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
from flash_attn_3.flash_attn_interface import (
flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
)
from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
fa_utils.set_flash_attention_3_params()
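The probe above now looks for the renamed `flash-attn-3` package rather than `flashattn-hopper`. A minimal sketch of the same detection using the standard library, assuming `get_pkg_version` behaves like `importlib.metadata.version`:

```python
# Detect an installed flash-attn v3 wheel; absence is not an error here because the
# warning is deferred to backend selection (get_attention_backend).
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

try:
    fa3_version = Version(version("flash-attn-3"))
except PackageNotFoundError:
    fa3_version = None
print("flash-attn v3:", fa3_version)
```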
......@@ -179,6 +156,7 @@ else:
_attention_backends = {
"attention_params": None,
"use_flash_attention": None,
"flash_attention_backend": None,
"use_fused_attention": None,
"fused_attention_backend": None,
"use_unfused_attention": None,
......@@ -487,6 +465,89 @@ def _get_cu_seqlens_info_with_cp(
return _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)]
def get_fa_args(
forward: bool,
use_flash_attn_3: bool,
qkv_format: str,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
dq=None,
dk=None,
dv=None,
):
"""Get forward/backward arguments for flash-attn v2 and v3."""
if use_flash_attn_3:
if forward:
if qkv_format == "thd":
return [
*[None] * 4, # k_new, v_new, qv, out
cu_seqlens_q,
cu_seqlens_kv,
*[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k
max_seqlen_q,
max_seqlen_kv,
*[None]
* 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
]
return [
*[None]
* 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k
max_seqlen_q,
max_seqlen_kv,
*[None]
* 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale
]
if qkv_format == "thd":
return [
cu_seqlens_q,
cu_seqlens_kv,
None,  # seqused_q
None,  # seqused_k
max_seqlen_q,
max_seqlen_kv,
dq,
dk,
dv,
]
return [
None, # cu_seqlens_q
None, # cu_seqlens_kv
None,  # seqused_q
None,  # seqused_k
max_seqlen_q,
max_seqlen_kv,
dq,
dk,
dv,
]
if forward:
if qkv_format == "thd":
return [
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
]
return []
if qkv_format == "thd":
return [
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
]
return [
dq,
dk,
dv,
]
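A hedged usage sketch of `get_fa_args` with placeholder metadata, assuming the helper above is in scope. It shows the difference between the long, mostly-`None` positional list expected by the flash-attn v3 kernels and the short varlen list used for v2; the call sites below splice the result in as `flash_attn_fwd(q, k, v, *fa_forward_args_thd, **fa_forward_kwargs)`.

```python
import torch

cu_q = torch.tensor([0, 3, 7], dtype=torch.int32)
cu_kv = torch.tensor([0, 5, 9], dtype=torch.int32)

# forward=True, use_flash_attn_3=True, qkv_format="thd"
args_v3 = get_fa_args(True, True, "thd", cu_seqlens_q=cu_q, cu_seqlens_kv=cu_kv,
                      max_seqlen_q=3, max_seqlen_kv=5)
# forward=True, use_flash_attn_3=False, qkv_format="thd"
args_v2 = get_fa_args(True, False, "thd", cu_seqlens_q=cu_q, cu_seqlens_kv=cu_kv,
                      max_seqlen_q=3, max_seqlen_kv=5)
assert len(args_v3) == 19  # k_new/v_new/page_table/descale slots filled with None
assert len(args_v2) == 4   # just cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
```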
class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
"""
Attention implementation with context parallelism. Exchange KV between CP ranks
......@@ -527,6 +588,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cp_stream,
quantizers,
pad_between_seqs,
use_flash_attn_3,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
......@@ -685,16 +747,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if use_fused_attention:
softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0)
else:
softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or fa_utils.use_v3
softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if fa_utils.use_v3:
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd_v3
else:
flash_attn_fwd = _flash_attn_fwd_v3
if use_flash_attn_3:
flash_attn_fwd = (
_flash_attn_fwd_v3 # pylint: disable=possibly-used-before-assignment
)
fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
else:
if qkv_format == "thd":
......@@ -703,7 +764,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_fwd = _flash_attn_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus) or fa_utils.use_v3:
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1)
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1
......@@ -856,14 +917,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv,
]
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
qkv_format,
cu_seqlens_q=cu_seqlens_q_per_step[i],
cu_seqlens_kv=cu_seqlens_kv_per_step[i],
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
fa_outputs = flash_attn_fwd(
q_inputs[i % 2],
(
......@@ -883,12 +945,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[3]
elif i <= rank:
if pad_between_seqs:
......@@ -986,15 +1048,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv // 2,
]
if fa_utils.use_v3 or (
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
qkv_format,
cu_seqlens_q=cu_seqlens_q_per_step[i],
cu_seqlens_kv=cu_seqlens_kv_per_step[i],
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv // 2,
)
if use_flash_attn_3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
):
fa_forward_kwargs["window_size"] = (-1, -1)
......@@ -1020,12 +1083,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[3]
else:
if pad_between_seqs:
......@@ -1132,15 +1195,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q // 2,
max_seqlen_kv,
]
if fa_utils.use_v3 or (
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
qkv_format,
cu_seqlens_q=cu_seqlens_q_per_step[i],
cu_seqlens_kv=cu_seqlens_kv_per_step[i],
max_seqlen_q=max_seqlen_q // 2,
max_seqlen_kv=max_seqlen_kv,
)
if use_flash_attn_3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
):
fa_forward_kwargs["window_size"] = (-1, -1)
......@@ -1166,12 +1230,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[3]
else:
if pad_between_seqs:
......@@ -1254,14 +1318,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors
attn_biases[i] = rest[0] if len(rest) > 0 else None
else:
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q_per_step[i],
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv,
]
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
qkv_format,
cu_seqlens_q=cu_seqlens_q_per_step[i],
cu_seqlens_kv=cu_seqlens_kv_per_step[i],
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
fa_outputs = flash_attn_fwd(
q,
(
......@@ -1281,12 +1346,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[3]
if i > 0:
......@@ -1478,6 +1543,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
ctx.use_flash_attn_3 = use_flash_attn_3
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
return out_ret
......@@ -1650,11 +1716,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
flash_attn_bwd = None
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if fa_utils.use_v3:
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd_v3
else:
flash_attn_bwd = _flash_attn_bwd_v3
if ctx.use_flash_attn_3:
flash_attn_bwd = (
_flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment
)
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
if ctx.qkv_format == "thd":
......@@ -1792,20 +1857,34 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
]
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
ctx.qkv_format,
cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_kv=ctx.max_seqlen_kv,
dq=dq_,
dk=(
dkv_[..., 0, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[0]
),
dv=(
dkv_[..., 1, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[1]
),
)
if ctx.use_flash_attn_3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
):
fa_backward_kwargs["window_size"] = (-1, 0)
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = -1
fa_backward_kwargs["window_size_right"] = 0
if not fa_utils.use_v3:
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
......@@ -1814,9 +1893,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
softmax_lse,
dq_,
dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
*fa_backward_args_thd,
causal=True,
**fa_backward_kwargs,
......@@ -1907,20 +1983,34 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv // 2,
]
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
ctx.qkv_format,
cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_kv=ctx.max_seqlen_kv // 2,
dq=dq_,
dk=(
dkv_[..., 0, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[0]
),
dv=(
dkv_[..., 1, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[1]
),
)
if ctx.use_flash_attn_3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
):
fa_backward_kwargs["window_size"] = (-1, -1)
if fa_utils.v2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = -1
fa_backward_kwargs["window_size_right"] = -1
if not fa_utils.use_v3:
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
......@@ -1929,9 +2019,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
softmax_lse,
dq_,
dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
*fa_backward_args_thd,
causal=False,
**fa_backward_kwargs,
......@@ -2024,20 +2111,34 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dq_ = torch.empty_like(q_)
dkv_ = torch.empty_like(kv_)
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q // 2,
ctx.max_seqlen_kv,
]
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
ctx.qkv_format,
cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
max_seqlen_q=ctx.max_seqlen_q // 2,
max_seqlen_kv=ctx.max_seqlen_kv,
dq=dq_,
dk=(
dkv_[..., 0, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[0]
),
dv=(
dkv_[..., 1, :, :]
if ctx.qkv_format in ["bshd", "sbhd"]
else dkv_[1]
),
)
if ctx.use_flash_attn_3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
):
fa_backward_kwargs["window_size"] = (-1, -1)
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = -1
fa_backward_kwargs["window_size_right"] = -1
if not fa_utils.use_v3:
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout_,
......@@ -2046,9 +2147,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1],
out_,
softmax_lse_,
dq_,
dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
*fa_backward_args_thd,
causal=False,
**fa_backward_kwargs,
......@@ -2118,20 +2216,24 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else:
dq_ = torch.empty_like(q)
dkv_ = torch.empty_like(kv)
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv_per_step[cp_size - i - 1],
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
]
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
ctx.qkv_format,
cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1],
cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1],
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_kv=ctx.max_seqlen_kv,
dq=dq_,
dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
)
if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_backward_kwargs["window_size"] = (-1, -1)
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = -1
fa_backward_kwargs["window_size_right"] = -1
if not fa_utils.use_v3:
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
flash_attn_bwd(
dout,
......@@ -2140,9 +2242,6 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1],
out,
softmax_lse,
dq_,
dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0],
dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1],
*fa_backward_args_thd,
causal=False,
**fa_backward_kwargs,
......@@ -2382,6 +2481,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -2440,6 +2540,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
window_size,
cp_group,
cp_stream,
use_flash_attn_3,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
......@@ -2465,10 +2566,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if fa_utils.use_v3:
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd_v3
else:
if use_flash_attn_3:
flash_attn_fwd = _flash_attn_fwd_v3
else:
if qkv_format == "thd":
......@@ -2584,15 +2682,16 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
window_size=window_size_per_step[i],
)
else:
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q,
cu_seqlens_kv_per_step[i],
max_seqlen_q,
max_seqlen_kv_,
]
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv_per_step[i],
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv_,
)
if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
fa_forward_kwargs["window_size"] = window_size_per_step[i]
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0]
......@@ -2608,12 +2707,12 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
if not fa_utils.v2_7_0_plus:
out_per_step[i] = fa_outputs[4]
softmax_lse_per_step[i] = fa_outputs[5]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[7]
else:
out_per_step[i] = fa_outputs[0]
softmax_lse_per_step[i] = fa_outputs[1]
if not fa_utils.use_v3:
if not use_flash_attn_3:
rng_states[i] = fa_outputs[3]
if i > 0:
......@@ -2658,6 +2757,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
ctx.attn_mask_type = attn_mask_type
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
ctx.use_flash_attn_3 = use_flash_attn_3
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
return out
......@@ -2713,10 +2813,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
flash_attn_bwd = None
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if fa_utils.use_v3:
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd_v3
else:
if ctx.use_flash_attn_3:
flash_attn_bwd = _flash_attn_bwd_v3
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
......@@ -2778,19 +2875,25 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
torch.empty_like(x) for x in [q_, k_, v_]
]
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q,
cu_seqlens_kv_per_step[i],
ctx.max_seqlen_q,
max_seqlen_kv,
]
if not fa_utils.use_v3:
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
ctx.qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv_per_step[i],
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
dq=dq_per_step[i],
dk=dk_per_step[i],
dv=dv_per_step[i],
)
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[i]
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
if ctx.use_flash_attn_3 or (
fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus
):
fa_backward_kwargs["window_size"] = window_size_per_step[i]
if fa_utils.v2_7_0_plus:
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0]
fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1]
flash_attn_bwd(
......@@ -2800,9 +2903,6 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
v_,
out_,
softmax_lse_per_step[i],
dq_per_step[i],
dk_per_step[i],
dv_per_step[i],
*fa_backward_args_thd,
causal="causal" in ctx.attn_mask_type,
**fa_backward_kwargs,
......@@ -2870,6 +2970,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -2906,6 +3007,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
cp_group,
cp_stream,
quantizers,
use_flash_attn_3,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
......@@ -2930,10 +3032,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
flash_attn_fwd = None
if not use_fused_attention:
fa_forward_kwargs = {"softmax_scale": softmax_scale}
if fa_utils.use_v3:
if qkv_format == "thd":
flash_attn_fwd = _flash_attn_varlen_fwd_v3
else:
if use_flash_attn_3:
flash_attn_fwd = _flash_attn_fwd_v3
fa_forward_kwargs["window_size"] = window_size
else:
......@@ -2943,7 +3042,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
flash_attn_fwd = _flash_attn_fwd
fa_forward_kwargs["dropout_p"] = dropout_p
fa_forward_kwargs["return_softmax"] = False
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size"] = window_size
elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = window_size[0]
......@@ -3048,14 +3147,15 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
if fp8:
out = out._data
else:
fa_forward_args_thd = []
if qkv_format == "thd":
fa_forward_args_thd = [
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
]
fa_forward_args_thd = get_fa_args(
True,
use_flash_attn_3,
qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
fa_outputs = flash_attn_fwd(
q,
k,
......@@ -3066,10 +3166,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
)
if not fa_utils.v2_7_0_plus:
out, softmax_lse = fa_outputs[4], fa_outputs[5]
rng_state = fa_outputs[7] if not fa_utils.use_v3 else None
rng_state = fa_outputs[7] if not use_flash_attn_3 else None
else:
out, softmax_lse = fa_outputs[0], fa_outputs[1]
rng_state = fa_outputs[3] if not fa_utils.use_v3 else None
rng_state = fa_outputs[3] if not use_flash_attn_3 else None
aux_ctx_tensors = [softmax_lse, rng_state]
chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device)
......@@ -3152,6 +3252,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
ctx.use_flash_attn_3 = use_flash_attn_3
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
return out_ret
......@@ -3240,11 +3341,10 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
flash_attn_bwd = None
if not ctx.use_fused_attention:
fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale}
if fa_utils.use_v3:
if ctx.qkv_format == "thd":
flash_attn_bwd = _flash_attn_varlen_bwd_v3
else:
flash_attn_bwd = _flash_attn_bwd_v3
if ctx.use_flash_attn_3:
flash_attn_bwd = (
_flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment
)
fa_backward_kwargs["window_size"] = ctx.window_size
fa_backward_kwargs["deterministic"] = ctx.deterministic
else:
......@@ -3253,7 +3353,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
else:
flash_attn_bwd = _flash_attn_bwd
fa_backward_kwargs["dropout_p"] = ctx.dropout_p
if fa_utils.use_v3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus):
if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size"] = ctx.window_size
elif fa_utils.v2_7_0_plus:
fa_backward_kwargs["window_size_left"] = ctx.window_size[0]
......@@ -3321,15 +3421,19 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
else:
softmax_lse, rng_state = aux_ctx_tensors
dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]]
fa_backward_args_thd = []
if ctx.qkv_format == "thd":
fa_backward_args_thd = [
cu_seqlens_q,
cu_seqlens_kv,
ctx.max_seqlen_q,
ctx.max_seqlen_kv,
]
if not fa_utils.use_v3:
fa_backward_args_thd = get_fa_args(
False,
ctx.use_flash_attn_3,
ctx.qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_kv=ctx.max_seqlen_kv,
dq=dq,
dk=dk,
dv=dv,
)
if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_state
flash_attn_bwd(
dout,
......@@ -3338,9 +3442,6 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
v,
out,
softmax_lse,
dq,
dk,
dv,
*fa_backward_args_thd,
causal=causal,
**fa_backward_kwargs,
......@@ -3427,6 +3528,7 @@ def attn_forward_func_with_cp(
fp8_meta=None,
quantizers=None,
pad_between_seqs=False,
use_flash_attn_3=False,
) -> torch.Tensor:
"""
Attention implementation with context parallelism.
......@@ -3494,15 +3596,24 @@ def attn_forward_func_with_cp(
]
if cp_comm_type in ["p2p", "a2a+p2p"]:
args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream, quantizers, pad_between_seqs]
args += [
fp8,
fp8_meta,
cp_group,
cp_global_ranks,
cp_stream,
quantizers,
pad_between_seqs,
use_flash_attn_3,
]
out = AttnFuncWithCPAndKVP2P.apply(*args)
elif cp_comm_type == "all_gather":
args.pop(5)
args.pop(8)
args += [window_size, cp_group, cp_stream]
args += [window_size, cp_group, cp_stream, use_flash_attn_3]
out = AttnFuncWithCPAndKVAllGather.apply(*args)
elif cp_comm_type == "a2a":
args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers]
args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3]
out = AttnFuncWithCPAndQKVOA2A.apply(*args)
else:
raise ValueError(f"Unsupported communication type: {cp_comm_type}!")
......@@ -3694,23 +3805,48 @@ class UnfusedDotProductAttention(torch.nn.Module):
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
) -> torch.Tensor:
"""Unfused attention fprop"""
assert (
qkv_layout in QKVLayouts
), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!"
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
# get q_format and kv_format for training and inference
qkv_format, q_format, _ = dpa_utils.get_qkv_format(qkv_layout, inference_params)
if inference_params is not None and inference_params.is_paged:
key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number)
if qkv_format == "bshd":
# convert to sbhd and use sbhd implementation for now
query_layer, key_layer, value_layer = [
x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
]
if qkv_format == "sbhd_2bshd":
key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]]
total_tokens, batch_size = None, None
if qkv_format == "thd_2bshd":
total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0]
query_layer = tex.convert_thd_to_bshd(
query_layer,
cu_seqlens_q,
batch_size,
inference_params.max_ctx_len,
)
query_layer, key_layer, value_layer = [
x.transpose(0, 1) for x in [query_layer, key_layer, value_layer]
]
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[1],
query_layer.shape[0],
key_layer.shape[0],
)
if "padding" in attn_mask_type and attention_mask is None:
attention_mask = dpa_utils.get_padding_mask(
batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
)
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
dpa_utils.get_full_mask(
max_seqlen_q,
......@@ -3843,20 +3979,34 @@ class UnfusedDotProductAttention(torch.nn.Module):
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
if qkv_format == "sbhd":
if q_format == "sbhd":
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
context_layer = context_layer.view(seqlen, batch_size, -1)
if qkv_format == "bshd":
if q_format == "bshd":
# [b, np, sq, hn] --> [b, sq, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# [b, sq, np, hn] --> [b, sq, hp]
context_layer = context_layer.view(batch_size, seqlen, -1)
if q_format == "thd":
# [b, np, sq, hn] --> [b, sq, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# [b, sq, np, hn] --> [tq, np, hn]
context_layer = tex.convert_bshd_to_thd(
context_layer,
cu_seqlens_q,
total_tokens,
)
# [tq, np, hn] --> [tq, hp]
context_layer = context_layer.view(total_tokens, -1)
return context_layer
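For inference with a `thd` query and a `bshd` KV cache, the unfused path above packs queries into a padded `bshd` buffer with `tex.convert_thd_to_bshd` and unpacks the result with `tex.convert_bshd_to_thd`. A rough eager-mode stand-in for those kernels, assuming they perform a plain scatter/gather keyed by the cumulative sequence lengths (the exact kernel semantics are an assumption):

```python
import torch

def convert_thd_to_bshd_ref(x_thd, cu_seqlens, batch_size, max_seqlen):
    """Scatter a packed [t, h, d] tensor into a zero-padded [b, s, h, d] buffer."""
    h, d = x_thd.shape[1], x_thd.shape[2]
    out = x_thd.new_zeros(batch_size, max_seqlen, h, d)
    for b in range(batch_size):
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        out[b, : end - start] = x_thd[start:end]
    return out

def convert_bshd_to_thd_ref(x_bshd, cu_seqlens, total_tokens):
    """Gather the valid tokens back into a packed [t, h, d] tensor, dropping padding."""
    pieces = [x_bshd[b, : int(cu_seqlens[b + 1]) - int(cu_seqlens[b])]
              for b in range(x_bshd.shape[0])]
    return torch.cat(pieces, dim=0)[:total_tokens]
```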
......@@ -3951,6 +4101,8 @@ class FlashAttention(torch.nn.Module):
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None,
inference_params: Optional[InferenceParams] = None,
flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
) -> torch.Tensor:
"""flash-attn fprop"""
......@@ -3973,8 +4125,10 @@ class FlashAttention(torch.nn.Module):
cp_size *= get_distributed_world_size(group)
context_parallel = cp_size > 1
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
# get q_format and kv_format for training and inference
qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params)
# convert q, k, v to bshd if they are in sbhd; qkv_format doesn't change
if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
if qkv_format == "sbhd":
# For now just 128, will make it more general in the future
......@@ -3988,8 +4142,11 @@ class FlashAttention(torch.nn.Module):
)
else:
query_layer, key_layer, value_layer = [
x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
x.transpose(0, 1).contiguous()
for x in (query_layer, key_layer, value_layer)
]
elif q_format == "sbhd" and kv_format == "bshd":
query_layer = query_layer.transpose(0, 1).contiguous()
if context_parallel:
query_layer, key_layer, value_layer = [
x.contiguous() for x in (query_layer, key_layer, value_layer)
......@@ -3997,27 +4154,37 @@ class FlashAttention(torch.nn.Module):
else:
if qkv_format == "sbhd":
query_layer._data, key_layer._data, value_layer._data = [
x.transpose(0, 1)
x.transpose(0, 1).contiguous()
for x in (query_layer._data, key_layer._data, value_layer._data)
]
query_layer, key_layer, value_layer = [
Float8Tensor.make_like(x, data=x._data, shape=x._data.shape)
for x in (query_layer, key_layer, value_layer)
]
elif q_format == "sbhd" and kv_format == "bshd":
query_layer._data = query_layer._data.transpose(0, 1).contiguous()
query_layer = Float8Tensor.make_like(
query_layer, data=query_layer._data, shape=query_layer._data.shape
)
if context_parallel:
query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
]
batch_size = query_layer.shape[0]
# get batch_size, max_seqlen and cu_seqlens
batch_size, context_len = None, None
if inference_params is None:
if qkv_format in ["sbhd", "bshd"]:
batch_size = query_layer.shape[0]
max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1]
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if "padding" in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism!"
assert (
not context_parallel
), "Padding mask not supported with context parallelism!"
# [b * s, h, d]
query_layer, key_layer, value_layer = [
x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])
......@@ -4083,7 +4250,33 @@ class FlashAttention(torch.nn.Module):
if max_seqlen_kv is None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = seqlens_kv.max().item()
else:
if qkv_format in ["sbhd_2bshd", "bshd"]:
# q is in bshd in both cases from conversion above or the original input
batch_size, context_len = query_layer.shape[:2]
cu_seqlens_q = cu_seqlens_q[: batch_size + 1]
cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1]
# convert from bshd to thd_2bshd for flash_attn_varlen_func/_with_kvcache;
# kernel assumes tensor is contiguous
if isinstance(query_layer, Float8Tensor):
query_layer._data = tex.convert_bshd_to_thd(
query_layer._data,
cu_seqlens_q,
batch_size * context_len,
)
query_layer = Float8Tensor.make_like(
query_layer, data=query_layer._data, shape=query_layer._data.shape
)
else:
query_layer = tex.convert_bshd_to_thd(
query_layer,
cu_seqlens_q,
batch_size * context_len,
)
use_flash_attn_3 = False
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
use_flash_attn_3 = True
if context_parallel and all(
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
):
......@@ -4114,6 +4307,7 @@ class FlashAttention(torch.nn.Module):
window_size=window_size,
quantizers=quantizers,
pad_between_seqs=False,
use_flash_attn_3=use_flash_attn_3,
)
else:
......@@ -4126,6 +4320,36 @@ class FlashAttention(torch.nn.Module):
tensor.activation_offloading = True
with self.attention_dropout_ctx():
# | API | use cases
# ----------------------------------------------------------------------
# FA v2 | flash_attn_func | bshd/sbhd + not padding
# | flash_attn_varlen_func | bshd/sbhd + padding
# | | thd + padding
# | | KV cache (not-paged/paged), i.e.
# | | bshd/sbhd/thd + padding
# FA v3 | flash_attn_func | bshd/sbhd + not padding
# | flash_attn_varlen_func | bshd/sbhd + padding
# | | thd + padding
# | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e.
# | | bshd/sbhd/thd + padding
fa_optional_forward_args_thd = []
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
func = (
flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
) # pylint: disable=possibly-used-before-assignment
else:
if not use_flash_attn_3:
func = flash_attn_varlen_func
elif inference_params is None:
func = flash_attn_varlen_func_v3 # pylint: disable=possibly-used-before-assignment
else:
func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment
if not use_flash_attn_3 or inference_params is None:
fa_optional_forward_args_thd.append(cu_seqlens_q)
fa_optional_forward_args_thd.append(cu_seqlens_kv)
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if not use_flash_attn_3:
fa_optional_forward_kwargs = {}
if fa_utils.v2_3_plus:
fa_optional_forward_kwargs["window_size"] = window_size
......@@ -4133,23 +4357,40 @@ class FlashAttention(torch.nn.Module):
fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes
if fa_utils.v2_4_1_plus:
fa_optional_forward_kwargs["deterministic"] = self.deterministic
fa_optional_forward_args_thd = []
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
func = flash_attn_func if not fa_utils.use_v3 else flash_attn_func_v3
else:
if fa_utils.v2_5_7_plus:
fa_optional_forward_kwargs["block_table"] = None
func = (
flash_attn_varlen_func if not fa_utils.use_v3 else flash_attn_varlen_func_v3
if inference_params is not None:
# use block_table kwarg to support thd_2bshd for non-paged
fa_optional_forward_kwargs["block_table"] = (
inference_params.cache_manager.page_table[:batch_size]
if inference_params.is_paged
else inference_params.cache_manager.batch_indices_post_step.unsqueeze(
1
)[:batch_size]
)
fa_optional_forward_args_thd.append(cu_seqlens_q)
fa_optional_forward_args_thd.append(cu_seqlens_kv)
fa_optional_forward_args_thd.append(max_seqlen_q)
fa_optional_forward_args_thd.append(max_seqlen_kv)
if fa_utils.use_v3:
output = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_optional_forward_kwargs,
)
else:
fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size
if inference_params is None:
fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
else:
fa_3_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
fa_3_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
fa_3_optional_forward_kwargs["cache_seqlens"] = cache_seqlens
# flash_attn_with_kvcache accepts thd_2bshd for non-paged
if inference_params.is_paged:
fa_3_optional_forward_kwargs["page_table"] = (
inference_params.cache_manager.page_table[:batch_size]
)
if fp8:
QKV_quantizer = quantizers["scaling_fwd"][META_QKV]
torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -4174,21 +4415,23 @@ class FlashAttention(torch.nn.Module):
query_layer, key_layer, value_layer = (
QKV_quantizer(x) for x in [query_layer, key_layer, value_layer]
)
fa_3_optional_forward_kwargs["descale_q"] = (
query_layer._scale_inv.unsqueeze(0)
batch_size = cu_seqlens_q.shape[0] - 1
num_heads_k = key_layer.shape[-2]
fa_3_optional_forward_kwargs["q_descale"] = (
query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k)
)
fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv.unsqueeze(
fa_3_optional_forward_kwargs["k_descale"] = key_layer._scale_inv.unsqueeze(
0
)
fa_3_optional_forward_kwargs["descale_v"] = (
value_layer._scale_inv.unsqueeze(0)
).repeat(batch_size, num_heads_k)
fa_3_optional_forward_kwargs["v_descale"] = (
value_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k)
)
query_layer, key_layer, value_layer = (
convert_to_torch_float8(x, torch_dtype)
for x in [query_layer, key_layer, value_layer]
)
try:
output, _ = func(
output = func(
query_layer,
key_layer,
value_layer,
......@@ -4197,6 +4440,8 @@ class FlashAttention(torch.nn.Module):
causal="causal" in attn_mask_type,
**fa_3_optional_forward_kwargs,
)
if isinstance(output, (List, Tuple)):
output = output[0]
except TypeError as e:
if fa_utils.v3_0_0_beta:
e.args = (
......@@ -4212,22 +4457,30 @@ class FlashAttention(torch.nn.Module):
if fp8 and fp8_meta["recipe"].fp8_mha:
O_quantizer = quantizers["scaling_fwd"][META_O]
output = O_quantizer(output)
else:
output = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_optional_forward_kwargs,
)
if inference_params is None:
if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type:
output = dpa_utils.UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output)
elif qkv_format in ["bshd", "sbhd_2bshd"]:
# all KV caching cases use thd_2bshd for calculation
# convert results back to bshd from thd_2bshd
if isinstance(query_layer, Float8Tensor):
output._data = tex.convert_thd_to_bshd(
output._data,
cu_seqlens_q,
batch_size,
context_len,
)
output = Float8Tensor.make_like(output, data=output._data, shape=output._data.shape)
else:
output = tex.convert_thd_to_bshd(
output,
cu_seqlens_q,
batch_size,
context_len,
)
if qkv_format == "sbhd":
if q_format == "sbhd":
# (bs)hd -> bs(hd) -> sb(hd)
if fp8 and fp8_meta["recipe"].fp8_mha:
output_data = (
......@@ -4242,10 +4495,10 @@ class FlashAttention(torch.nn.Module):
)
else:
output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
elif qkv_format == "bshd":
elif q_format == "bshd":
# (bs)hd -> bs(hd)
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
elif qkv_format == "thd":
elif q_format == "thd":
# thd -> t(hd)
output = output.reshape(output.shape[0], -1)
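During KV caching, the flash-attn v2 varlen path above always passes a `block_table`: the cache manager's page table for paged caches, or a `[batch, 1]` tensor of batch indices for non-paged `bshd` caches, where each sequence's whole cache row acts as a single large page. A hedged illustration of the two shapes with placeholder values (the real tensors come from `inference_params.cache_manager`):

```python
import torch

batch_size, max_pages_per_seq = 4, 8

# Paged cache: one row of page indices per sequence, shape [b, max_pages_per_seq].
page_table = torch.arange(batch_size * max_pages_per_seq,
                          dtype=torch.int32).reshape(batch_size, max_pages_per_seq)

# Non-paged bshd cache: the "table" is just each sequence's batch index, shape [b, 1].
batch_indices_post_step = torch.arange(batch_size, dtype=torch.int32)
block_table_nonpaged = batch_indices_post_step.unsqueeze(1)

assert page_table.shape == (4, 8) and block_table_nonpaged.shape == (4, 1)
```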
......@@ -4296,6 +4549,8 @@ class FusedAttnFunc(torch.autograd.Function):
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
page_table_k,
page_table_v,
q,
k,
v,
......@@ -4340,7 +4595,7 @@ class FusedAttnFunc(torch.autograd.Function):
q_fp8, k_fp8, v_fp8 = q, k, v
else:
# 1: qkv packed, 2: kv packed, 3: qkv separate
qkv_group = len(qkv_layout.split("_"))
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
match qkv_group:
case 1:
dim = qkv_layout.find("3")
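Stripping the `paged_kv_` prefix before splitting keeps the packing arithmetic identical for paged and non-paged layouts. A quick illustration of the parsing above (layout strings only, no tensors involved):

```python
# 1: qkv packed, 2: kv packed, 3: q, k, v separate
for layout in ("sbh3d", "bshd_bs2hd", "thd_thd_thd", "paged_kv_thd_bshd_bshd"):
    qkv_group = len(layout.replace("paged_kv_", "").split("_"))
    print(layout, "->", qkv_group)  # 1, 2, 3, 3
```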
......@@ -4376,6 +4631,8 @@ class FusedAttnFunc(torch.autograd.Function):
attn_bias,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
None,
None,
S_quantizer,
O_quantizer,
attn_scale,
......@@ -4398,7 +4655,7 @@ class FusedAttnFunc(torch.autograd.Function):
if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")):
# 1: qkv packed, 2: kv packed, 3: qkv separate
if is_input_fp8:
qkv_group = len(qkv_layout.split("_"))
qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_"))
if qkv_group == 1:
dim = qkv_layout.find("3")
qkv = _combine_tensors([q, k, v], dim)
......@@ -4407,7 +4664,7 @@ class FusedAttnFunc(torch.autograd.Function):
q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True)
if qkv_group == 2:
q = q.dequantize()
dim = qkv_layout.split("_")[1].find("2")
dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2")
kv = _combine_tensors([k, v], dim)
kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1])
kv_no_fp8 = kv.dequantize()
......@@ -4436,6 +4693,8 @@ class FusedAttnFunc(torch.autograd.Function):
attn_bias,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
page_table_k,
page_table_v,
None, # s_quantizer
None, # o_quantizer
attn_scale,
......@@ -4612,7 +4871,7 @@ class FusedAttnFunc(torch.autograd.Function):
# is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16
# is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2
if not ctx.is_input_fp8:
qkv_group = len(ctx.qkv_layout.split("_"))
qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_"))
if qkv_group == 1:
dim = ctx.qkv_layout.find("3")
dqkv_fp8_data = _combine_tensors(
......@@ -4682,6 +4941,8 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
None,
None,
dq,
dk,
dv,
......@@ -4712,6 +4973,8 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
None,
None,
dq,
dk,
dv,
......@@ -4833,6 +5096,7 @@ class FusedAttention(torch.nn.Module):
fp8_meta: Optional[Dict[str, Any]] = None,
quantizers=None,
pad_between_seqs: bool = False,
inference_params: Optional[InferenceParams] = None,
) -> torch.Tensor:
"""fused attention fprop"""
assert (
......@@ -4857,26 +5121,26 @@ class FusedAttention(torch.nn.Module):
cp_size *= get_distributed_world_size(group)
context_parallel = cp_size > 1
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
# get q_format and kv_format for training and inference
qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params)
page_table = None
if inference_params is None:
if qkv_format in ["sbhd", "bshd"]:
if qkv_format == "sbhd":
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[1],
query_layer.shape[0],
key_layer.shape[0],
)
batch_size = query_layer.shape[1]
max_seqlen_q = query_layer.shape[0]
max_seqlen_kv = key_layer.shape[0]
if qkv_format == "bshd":
batch_size, max_seqlen_q, max_seqlen_kv = (
query_layer.shape[0],
query_layer.shape[1],
key_layer.shape[1],
)
batch_size = query_layer.shape[0]
max_seqlen_q = query_layer.shape[1]
max_seqlen_kv = key_layer.shape[1]
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if "padding" in attn_mask_type:
assert not context_parallel, "Padding mask not supported with context parallelism!"
assert (
not context_parallel
), "Padding mask not supported with context parallelism!"
if cu_seqlens_q is None or cu_seqlens_kv is None:
if attention_mask is None:
raise RuntimeError(
......@@ -4908,9 +5172,12 @@ class FusedAttention(torch.nn.Module):
and cu_seqlens_q is not None
and cu_seqlens_kv is not None
), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
elif inference_params.is_paged:
page_table = inference_params.cache_manager.page_table
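# page_table maps each sequence's logical KV pages to physical pages in the cache,
# with shape [batch_size, max_pages_per_seq] (see page_table_k/v in fused_attn_fwd)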
if qkv_format == "thd" and (cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None):
if (q_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_q_padded is None:
cu_seqlens_q_padded = cu_seqlens_q
if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None:
cu_seqlens_kv_padded = cu_seqlens_kv
use_FAv2_bwd = (
......@@ -4981,6 +5248,8 @@ class FusedAttention(torch.nn.Module):
cu_seqlens_kv,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
page_table,
page_table,
query_layer,
key_layer,
value_layer,
......@@ -5369,14 +5638,14 @@ class DotProductAttention(TransformerEngineBaseModule):
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
qkv_format: Optional[str] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
qkv_format: str = None,
cu_seqlens_q: torch.Tensor = None,
cu_seqlens_kv: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
max_seqlen_q: int = None,
max_seqlen_kv: int = None,
attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
checkpoint_core_attention: bool = False,
......@@ -5565,6 +5834,16 @@ class DotProductAttention(TransformerEngineBaseModule):
num_gemms=3,
allow_non_contiguous=True,
) as query_layer:
# checks for RNG
if self.rng_states_tracker is not None and is_graph_capturing():
assert isinstance(
self.rng_states_tracker, CudaRNGStatesTracker
), "Unsupported RNG states tracker."
assert (
graph_safe_rng_available()
), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
# checks for FP8
if self.fp8:
if self.fp8_meta["recipe"].fp8_mha:
if not self.fp8_meta["recipe"].fp8_dpa:
......@@ -5573,7 +5852,6 @@ class DotProductAttention(TransformerEngineBaseModule):
"""Forcing fp8_meta["recipe"].fp8_dpa=True due to """
"""fp8_meta["recipe"].fp8_mha=True"""
)
if self.fp8 and self.fp8_meta["recipe"].fp8_dpa:
forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True)
backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False)
......@@ -5585,6 +5863,7 @@ class DotProductAttention(TransformerEngineBaseModule):
tex.DType.kFloat8E5M2,
], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types."""
# checks for q/k/v shapes
assert (
query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda
), "DotProductAttention only supports CUDA tensors."
......@@ -5594,18 +5873,26 @@ class DotProductAttention(TransformerEngineBaseModule):
assert (
key_layer.shape[:-1] == value_layer.shape[:-1]
), "Keys and values must have the same batch size, sequence length and number of heads!"
num_attention_heads = query_layer.shape[-2]
num_gqa_groups = key_layer.shape[-2]
assert (
query_layer.shape[-1] == key_layer.shape[-1]
), "Queries and keys must have the same head dimension!"
head_dim_qk, head_dim_v = query_layer.shape[-1], value_layer.shape[-1]
assert (
key_layer.shape[-1] == self.hidden_size_per_attention_head_k
), f"Keys have head_dim = {key_layer.shape[-1]}, "
head_dim_qk == self.hidden_size_per_attention_head_k
), f"Keys have head_dim = {head_dim_qk}, "
"but expected head_dim = {self.hidden_size_per_attention_head_k}!"
assert (
value_layer.shape[-1] == self.hidden_size_per_attention_head_v
), f"Values have head_dim = {value_layer.shape[-1]}, "
head_dim_v == self.hidden_size_per_attention_head_v
), f"Values have head_dim = {head_dim_v}, "
"but expected head_dim = {self.hidden_size_per_attention_head_v}!"
assert num_gqa_groups == self.num_gqa_groups_per_partition, (
"Keys and values must have num_gqa_group ="
f" {self.num_gqa_groups_per_partition} heads! Found {num_gqa_groups}."
)
if qkv_format is None:
qkv_format = self.qkv_format
# checks for attention mask
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
else:
......@@ -5615,82 +5902,40 @@ class DotProductAttention(TransformerEngineBaseModule):
assert (
attn_mask_type in AttnMaskTypes
), f"Attention mask type {attn_mask_type} is not supported!"
if qkv_format == "thd":
assert (
"padding" in attn_mask_type
), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
# checks for sliding window
if window_size is None:
window_size = self.window_size
window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if self.rng_states_tracker is not None and is_graph_capturing():
assert isinstance(
self.rng_states_tracker, CudaRNGStatesTracker
), "Unsupported RNG states tracker."
assert (
graph_safe_rng_available()
), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."
if inference_params is not None:
assert self.layer_number is not None, "Layer number must be set!"
# convert causal to causal_bottom_right in inference when KV-caching is in use
# so users can run with the same attn_mask_type for training and inference
if attn_mask_type in ["causal", "padding_causal"]:
attn_mask_type = attn_mask_type + "_bottom_right"
if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)
(
inference_key_memory,
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy keys and values into KV-cache
inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
key_layer
)
inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = (
value_layer
)
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
assert (
key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), (
"Keys and values must have num_gqa_group ="
f" {self.num_gqa_groups_per_partition} heads!"
)
# checks for qkv_format
if qkv_format is None:
qkv_format = self.qkv_format
assert qkv_format in [
"sbhd",
"bshd",
"thd",
], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!"
batch_size = None
if qkv_format in ["sbhd", "bshd"]:
assert all(
len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
), f"Queries, keys and values must be 4D tensors when {qkv_format=}!"
if qkv_format == "sbhd":
batch_size = query_layer.shape[1]
max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
else:
batch_size = query_layer.shape[0]
max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
if qkv_format == "thd":
assert all(
len(x.shape) == 3 for x in (query_layer, key_layer, value_layer)
), "Queries, keys and values must be 3D tensors when qkv_format = thd!"
assert (
"padding" in attn_mask_type
), "Attention mask type must be padding or padding_causal for qkv_format=thd!"
assert (
cu_seqlens_q is not None and cu_seqlens_kv is not None
), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!"
......@@ -5716,6 +5961,76 @@ class DotProductAttention(TransformerEngineBaseModule):
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
# update KV cache and retrieve saved tokens from cache for inference
if inference_params is not None:
assert self.layer_number is not None, "Layer number must be set!"
# convert top-left causal to bottom-right causal when KV caching is in use,
# so users can keep the same attention mask type for inference as for training
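# e.g. 2 new query tokens attending to 10 cached + 2 new KV tokens form a [2, 12] score
# matrix whose causal diagonal must align with its bottom-right corner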
assert "padding" in attn_mask_type, "KV caching requires padding mask!"
if attn_mask_type == "padding_causal":
attn_mask_type = attn_mask_type + "_bottom_right"
self.attention_type = "cross"
self.flash_attention.attention_type = self.attention_type
self.fused_attention.attention_type = self.attention_type
self.unfused_attention.attention_type = self.attention_type
query_layer, key_layer, value_layer = [
x.contiguous() if not x.is_contiguous() else x
for x in [query_layer, key_layer, value_layer]
]
# get full K/V tensors from cache and adjust cu_seqlens, qkv_format based on the cache
(
key_layer,
value_layer,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_kv,
qkv_format,
) = inference_params.step(
self.layer_number,
key_layer,
value_layer,
qkv_format,
)
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
# get qkv's memory layout
if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]):
(
qkv_layout,
query_layer._data,
key_layer._data,
value_layer._data,
q_format,
kv_format,
) = dpa_utils.get_qkv_layout(
query_layer._data,
key_layer._data,
value_layer._data,
qkv_format=qkv_format,
inference_params=inference_params,
)
else:
(
qkv_layout,
query_layer,
key_layer,
value_layer,
q_format,
kv_format,
) = dpa_utils.get_qkv_layout(
query_layer,
key_layer,
value_layer,
qkv_format=qkv_format,
inference_params=inference_params,
)
# adjust max_seqlen and cu_seqlens for CP
cp_size = 1
if isinstance(self.cp_group, dist_group_type):
cp_size = get_distributed_world_size(self.cp_group)
......@@ -5723,71 +6038,42 @@ class DotProductAttention(TransformerEngineBaseModule):
for group in self.cp_group:
cp_size *= get_distributed_world_size(group)
context_parallel = cp_size > 1
if qkv_format in ["sbhd", "bshd"]:
assert all(
len(x.shape) == 4 for x in (query_layer, key_layer, value_layer)
), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!"
if qkv_format == "sbhd":
max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q
max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv
batch_size = query_layer.shape[1]
else:
max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q
max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv
batch_size = query_layer.shape[0]
if q_format in ["sbhd", "bshd"]:
max_seqlen_q *= cp_size
max_seqlen_kv *= cp_size
if cu_seqlens_q is not None:
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
assert all(
seqlens_q <= max_seqlen_q
), """Sequence lengths indicated by cu_seqlens_q must be no greater than
the sequence dimension in 'query_layer'!"""
if cu_seqlens_kv is not None:
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
assert all(
seqlens_kv <= max_seqlen_kv
), """Sequence lengths indicated by cu_seqlens_kv must be no greater than
the sequence dimension in 'key_layer' and 'value_layer'!"""
if cu_seqlens_q is None or cu_seqlens_kv is None:
if cu_seqlens_q is None:
if "padding" in attn_mask_type:
assert (
attention_mask is not None
), "Please provide attention_mask for padding!"
if self.attention_type == "self":
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask)
cu_seqlens_kv = cu_seqlens_q
else:
cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0])
cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
else:
cu_seqlens_q = dpa_utils.get_full_cu_seqlens(
batch_size,
max_seqlen_q,
query_layer.device,
)
if kv_format in ["sbhd", "bshd"]:
max_seqlen_kv *= cp_size
if cu_seqlens_kv is None:
if "padding" in attn_mask_type:
assert (
attention_mask is not None
), "Please provide attention_mask for padding!"
if self.attention_type == "self":
cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask)
else:
cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1])
else:
cu_seqlens_kv = dpa_utils.get_full_cu_seqlens(
batch_size,
max_seqlen_kv,
key_layer.device,
)
if (
isinstance(query_layer, Float8Tensor)
and isinstance(key_layer, Float8Tensor)
and isinstance(value_layer, Float8Tensor)
):
qkv_layout, query_layer._data, key_layer._data, value_layer._data = (
dpa_utils.get_qkv_layout(
query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format
)
)
else:
qkv_layout, query_layer, key_layer, value_layer = dpa_utils.get_qkv_layout(
query_layer, key_layer, value_layer, qkv_format=qkv_format
)
# set ALiBi attributes
global _alibi_cache
if alibi_slopes is not None:
assert (
......@@ -5811,6 +6097,7 @@ class DotProductAttention(TransformerEngineBaseModule):
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
# detect bias shape
core_attention_bias_shape = None
if core_attention_bias is not None:
if (
......@@ -5846,17 +6133,18 @@ class DotProductAttention(TransformerEngineBaseModule):
else:
pad_between_seqs = False
# gather attention params for get_attention_backend
attention_params = dpa_utils.AttentionParams(
qkv_type=type(query_layer),
qkv_dtype=query_layer.dtype,
qkv_layout=qkv_layout,
batch_size=batch_size,
num_heads=query_layer.shape[-2],
num_gqa_groups=key_layer.shape[-2],
num_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
head_dim_qk=query_layer.shape[-1],
head_dim_v=value_layer.shape[-1],
head_dim_qk=head_dim_qk,
head_dim_v=head_dim_v,
attn_mask_type=attn_mask_type,
window_size=window_size,
alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
......@@ -5872,6 +6160,7 @@ class DotProductAttention(TransformerEngineBaseModule):
is_training=self.training,
fp8=self.fp8,
fp8_meta=self.fp8_meta,
inference_params=inference_params,
)
global _attention_backends
if (
......@@ -5881,9 +6170,9 @@ class DotProductAttention(TransformerEngineBaseModule):
_attention_backends["attention_params"] = attention_params
_attention_backends["backend_selection_requires_update"] = True
if _attention_backends["backend_selection_requires_update"]:
fa_utils.use_v3 = fa_utils.v3_is_installed
(
use_flash_attention,
flash_attention_backend,
use_fused_attention,
fused_attention_backend,
use_unfused_attention,
......@@ -5892,6 +6181,7 @@ class DotProductAttention(TransformerEngineBaseModule):
# Set global _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
_attention_backends["flash_attention_backend"] = flash_attention_backend
_attention_backends["use_fused_attention"] = use_fused_attention
_attention_backends["fused_attention_backend"] = fused_attention_backend
_attention_backends["use_unfused_attention"] = use_unfused_attention
......@@ -5899,7 +6189,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if use_flash_attention:
self.logger.info(
"Running with FlashAttention backend (version %s)",
fa_utils.version if not fa_utils.use_v3 else fa_utils.fa3_version,
flash_attention_backend,
)
elif use_fused_attention:
self.logger.info(
......@@ -5910,10 +6200,16 @@ class DotProductAttention(TransformerEngineBaseModule):
self.logger.info("Running with UnfusedDotProductAttention backend")
else:
use_flash_attention = _attention_backends["use_flash_attention"]
flash_attention_backend = _attention_backends["flash_attention_backend"]
use_fused_attention = _attention_backends["use_fused_attention"]
fused_attention_backend = _attention_backends["fused_attention_backend"]
use_unfused_attention = _attention_backends["use_unfused_attention"]
# raise exception if no backend is available
if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0:
raise ValueError("No dot product attention support for the provided inputs!")
# run attention
if use_flash_attention:
if core_attention_bias_type == "alibi":
alibi_slopes, _ = dpa_utils.get_alibi(
......@@ -5943,6 +6239,8 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
inference_params=inference_params,
flash_attention_backend=flash_attention_backend,
)
if use_fused_attention:
......@@ -5961,6 +6259,7 @@ class DotProductAttention(TransformerEngineBaseModule):
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
)
# checkpoint_core_attention=False
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.fused_attention,
......@@ -5987,7 +6286,9 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_comm_type=self.cp_comm_type,
fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
)
return self.fused_attention(
query_layer,
......@@ -6015,6 +6316,7 @@ class DotProductAttention(TransformerEngineBaseModule):
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
pad_between_seqs=pad_between_seqs,
inference_params=inference_params,
)
from .cpu_offload import CPUOffloadEnabled
......@@ -6041,6 +6343,7 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
)
return self.unfused_attention(
query_layer,
......@@ -6055,9 +6358,9 @@ class DotProductAttention(TransformerEngineBaseModule):
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
inference_params=inference_params,
)
raise ValueError("No dot product attention support for the provided inputs!")
return None
class MultiheadAttention(torch.nn.Module):
......@@ -6241,7 +6544,7 @@ class MultiheadAttention(torch.nn.Module):
self.qkv_format = qkv_format
self.attn_mask_type = attn_mask_type
self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
self.layer_number = layer_number
self.layer_number = 1 if layer_number is None else layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
self.get_rng_state_tracker = get_rng_state_tracker
......@@ -6410,19 +6713,6 @@ class MultiheadAttention(torch.nn.Module):
**common_gemm_kwargs,
)
def _allocate_memory(
self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype
) -> torch.Tensor:
"""Allocates memory for KV cache."""
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_gqa_groups_per_partition,
self.hidden_size_per_attention_head,
dtype=dtype,
device=torch.cuda.current_device(),
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""
Set the tensor parallel group for the given
......@@ -6611,31 +6901,14 @@ class MultiheadAttention(torch.nn.Module):
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"
# =================================================
# Pre-allocate memory for key-values for inference
# Pre-allocate memory for key-value cache for inference
# =================================================
if inference_params and self.layer_number is not None:
assert (
self.qkv_format != "thd"
), "qkv_format == thd is not supported for an inference with KV-cache!"
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
inference_value_memory,
)
else:
(
inference_key_memory,
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]
if (
inference_params is not None
and self.layer_number not in inference_params.cache_manager.cache
):
inference_params.allocate_memory(self.layer_number)
# ======================
# Query, Key, and Value
......@@ -6801,9 +7074,12 @@ class MultiheadAttention(torch.nn.Module):
elif self.qkv_format == "bshd":
sequence_length = key_layer.size(1)
else:
raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.")
raise ValueError(
f"qkv_format={self.qkv_format} not supported for KV caching and RoPE."
)
sequence_start = inference_params.sequence_len_offset
sequence_start = inference_params.get_seqlens_pre_step()
# sequence_start = inference_params.seqlens[0]
sequence_end = sequence_start + sequence_length
q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
......
......@@ -64,6 +64,16 @@ QKVLayouts = (
"thd_t2hd",
"thd_th2d",
"thd_thd_thd",
"sbhd_bshd_bshd",
"bshd_sbhd_sbhd",
"thd_bshd_bshd",
"thd_sbhd_sbhd",
"paged_kv_bshd_bshd_bshd",
"paged_kv_bshd_sbhd_sbhd",
"paged_kv_sbhd_bshd_bshd",
"paged_kv_sbhd_sbhd_sbhd",
"paged_kv_thd_bshd_bshd",
"paged_kv_thd_sbhd_sbhd",
)
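# Mixed layouts such as "thd_bshd_bshd" (query format != KV format) and the "paged_kv_" prefix
# are used for KV-cache inference, where queries and the KV cache may be stored in different formats.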
LayerTypes = ("encoder", "decoder")
......
......@@ -9,6 +9,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import (
NVTE_QKV_Layout,
NVTE_QKV_Format,
NVTE_Bias_Type,
NVTE_Mask_Type,
NVTE_Fused_Attn_Backend,
......@@ -31,6 +32,16 @@ TORCH_DType = {
tex.DType.kInt32: torch.int32,
}
QKVFormat = {
"bshd": NVTE_QKV_Format.NVTE_BSHD,
"sbhd": NVTE_QKV_Format.NVTE_SBHD,
"thd": NVTE_QKV_Format.NVTE_THD,
"sbhd_2bshd": NVTE_QKV_Format.NVTE_SBHD_2BSHD,
"bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD,
"thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD,
"thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD,
}
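# Combined formats describe mixed Q/KV formats during KV-cache inference, e.g. "thd_2bshd"
# means queries are in thd format while the (possibly paged) KV cache is in bshd format.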
QKVLayout = {
"sb3hd": NVTE_QKV_Layout.NVTE_SB3HD,
"sbh3d": NVTE_QKV_Layout.NVTE_SBH3D,
......@@ -47,6 +58,16 @@ QKVLayout = {
"thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD,
"thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D,
"thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD,
"sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_SBHD_BSHD_BSHD,
"bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_BSHD_SBHD_SBHD,
"thd_bshd_bshd": NVTE_QKV_Layout.NVTE_THD_BSHD_BSHD,
"thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_THD_SBHD_SBHD,
"paged_kv_bshd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_BSHD_BSHD,
"paged_kv_bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_SBHD_SBHD,
"paged_kv_sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_BSHD_BSHD,
"paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD,
"paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD,
"paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD,
}
AttnBiasType = {
......@@ -100,6 +121,8 @@ def fused_attn_fwd(
attn_bias: torch.Tensor = None,
cu_seqlens_q_padded: torch.Tensor = None,
cu_seqlens_kv_padded: torch.Tensor = None,
page_table_k: torch.Tensor = None,
page_table_v: torch.Tensor = None,
s_quantizer: Quantizer = None,
o_quantizer: Quantizer = None,
attn_scale: float = None,
......@@ -148,6 +171,10 @@ def fused_attn_fwd(
cumulative sequence offsets for Q; shape [batch_size + 1]
cu_seqlens_kv_padded: torch.Tensor, default = None
cumulative sequence offsets for KV; shape [batch_size + 1]
page_table_k: torch.Tensor, default = None
page table for K cache; shape [batch_size, max_pages_per_seq_k]
page_table_v: torch.Tensor, default = None
page table for V cache; shape [batch_size, max_pages_per_seq_v]
s_quantizer: Quantizer, default = None
Quantizer object for the intermediate value S.
o_quantizer: Quantizer, default = None
......@@ -268,6 +295,8 @@ def fused_attn_fwd(
fake_dtype,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
page_table_k,
page_table_v,
s_quantizer,
o_quantizer,
attn_bias,
......
......@@ -14,6 +14,8 @@
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
......
......@@ -51,8 +51,9 @@ std::vector<py::object> fused_attn_fwd(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
std::vector<py::object> fused_attn_bwd(
......@@ -69,6 +70,13 @@ std::vector<py::object> fused_attn_bwd(
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);
void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged);
/***************************************************************************************************
* GEMM
**************************************************************************************************/
......
......@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "kv_cache.cuh"
#include "thd_utils.cuh"
constexpr int block_size = 512;
......@@ -90,8 +91,9 @@ std::vector<py::object> fused_attn_fwd(
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Tensor> cu_seqlens_kv_padded,
const c10::optional<at::Tensor> page_table_k, const c10::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
......@@ -126,6 +128,7 @@ std::vector<py::object> fused_attn_fwd(
TensorWrapper te_Bias;
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded;
TensorWrapper te_page_table_k, te_page_table_v;
if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
// FP8
auto h = q_shape[q_shape.size() - 2];
......@@ -170,6 +173,19 @@ std::vector<py::object> fused_attn_fwd(
cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32);
}
if ((page_table_k.has_value()) && (page_table_v.has_value())) {
auto page_table_k_sizes = page_table_k.value().sizes().vec();
std::vector<size_t> page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()};
auto page_table_v_sizes = page_table_v.value().sizes().vec();
std::vector<size_t> page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()};
te_page_table_k =
makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape,
DType::kInt32, nullptr, nullptr, nullptr);
te_page_table_v =
makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape,
DType::kInt32, nullptr, nullptr, nullptr);
}
// extract rng seed and offset
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
......@@ -187,13 +203,13 @@ std::vector<py::object> fused_attn_fwd(
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
nvte_fused_attn_fwd(
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
workspace.data(), at::cuda::getCurrentCUDAStream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
......@@ -241,13 +257,13 @@ std::vector<py::object> fused_attn_fwd(
}
// execute the kernel
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(),
te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q,
max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type,
attn_mask_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
nvte_fused_attn_fwd(
te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(),
&nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1],
workspace.data(), at::cuda::getCurrentCUDAStream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
......@@ -1012,3 +1028,174 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
return output;
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
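// Illustrative shapes: a thd input of shape [t, h, d] with cu_seqlens = [0, 2, 5], b = 2 and
// max_seq_len = 4 yields a bshd output of shape [2, 4, h, d], zero-filled past each sequence length.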
template <typename scalar_t>
void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn::
convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d);
}
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) {
int h = tensor.size(1);
int d = tensor.size(2);
std::vector<int64_t> shape = {b, max_seq_len, h, d};
at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
if (new_tensor.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float) {
using dtype = float;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
return new_tensor;
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
template <typename scalar_t>
void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn::
convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d);
}
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
int b = tensor.size(0);
int max_seq_len = tensor.size(1);
int h = tensor.size(2);
int d = tensor.size(3);
std::vector<int64_t> shape = {t, h, d};
at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
if (tensor.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float) {
using dtype = float;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
return new_tensor;
}
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens have shape [b + 1]; cu_cached_lens includes the lengths
* added in the current step
* 3. A non-paged KV cache is a special case of a paged KV cache, with a page_table of shape
* [b, 1] and max_pages_per_seq = 1. The same underlying kernel is used for both non-paged and
* paged caches; set is_non_paged accordingly.
* 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
* become [0, 1, 1, 2]. The page_table (= batch_indices.unsqueeze(1)) is, however, unchanged.
* batch_indices_post can be used for monotonic indexing, i.e. [0, 1, 2, 3]; batch_indices is
* preserved for the next layer in the same iteration.
* 5. Only a shared page_table for k_cache and v_cache is supported
* 6. Only pad_between_seqs = False is supported when qkv_format = thd, i.e. new_k and new_v must
* not contain pad tokens between sequences, such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
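// Illustrative non-paged example: for b = 4 with batch_indices = [0, 3, 1, 2], page_table is
// batch_indices.unsqueeze(1) = [[0], [3], [1], [2]] (shape [b, 1]), and each "page" holds up to
// max_seq_len tokens of one sequence's K/V in the bshd cache.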
template <typename scalar_t>
void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache,
at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens,
at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr &&
v_cache.data_ptr() != nullptr) {
if (is_non_paged) {
transformer_engine::fused_attn::
reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()),
page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
}
transformer_engine::fused_attn::
copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()), page_table.data_ptr<int>(),
cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v,
b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
}
void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
int h_kv = new_k.size(-2);
int d_k = new_k.size(-1);
int d_v = new_v.size(-1);
NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() &&
new_k.scalar_type() == new_v.scalar_type() &&
new_k.scalar_type() == k_cache.scalar_type(),
"new_k, new_v, k_cache and v_cache must be of the same data type.");
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
if (k_cache.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float) {
using dtype = float;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
}
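For reference, below is a minimal Python sketch of exercising the non-paged path of copy_to_kv_cache through the bindings above. It assumes the function is exposed on the transformer_engine_torch module under the same name, that max_ctx_len is the maximum number of new tokens per sequence in the current step, and that the non-paged cache is laid out as [b, max_seq_len, h_kv, d]; shapes and values are illustrative only, not prescriptive.

import torch
import transformer_engine_torch as tex
from transformer_engine_torch import NVTE_QKV_Format

b, h_kv, d, max_seq_len = 2, 4, 64, 128
# one new token per sequence (a decode step); new K/V are in bshd format
new_k = torch.randn(b, 1, h_kv, d, dtype=torch.bfloat16, device="cuda")
new_v = torch.randn(b, 1, h_kv, d, dtype=torch.bfloat16, device="cuda")
# non-paged cache: a single "page" of max_seq_len tokens per sequence (assumed bshd layout)
k_cache = torch.zeros(b, max_seq_len, h_kv, d, dtype=torch.bfloat16, device="cuda")
v_cache = torch.zeros_like(k_cache)
page_table = torch.arange(b, dtype=torch.int32, device="cuda").unsqueeze(1)  # shape [b, 1]
cu_new_lens = torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda")
# cached lengths include the tokens added in this step, e.g. 4 + 1 and 5 + 1 tokens
cu_cached_lens = torch.tensor([0, 5, 11], dtype=torch.int32, device="cuda")
tex.copy_to_kv_cache(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, cu_cached_lens,
                     NVTE_QKV_Format.NVTE_BSHD, b, 1, max_seq_len, 1, True)
# trailing args: b, max_ctx_len=1, max_seq_len, max_pages_per_seq=1, is_non_paged=True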