Unverified Commit 2ce3d0ce authored by JartX's avatar JartX Committed by GitHub
Browse files

[Feature] KV cache per-token-head INT8/FP8 quantization (#38378)


Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: yangyang4991 <yangyang4991@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent 4eefbf96
...@@ -177,7 +177,7 @@ Priority is **1 = highest** (tried first). ...@@ -177,7 +177,7 @@ Priority is **1 = highest** (tried first).
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A | | `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A | | `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | Decoder, Encoder, Encoder Only | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | | `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any | | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2`, `int8_per_token_head`, `fp8_per_token_head` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
> >
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""End-to-end accuracy tests for per-token-head KV cache quantization.
Compares logprobs between a baseline bf16 model and the same model with
per-token-head quantized KV cache (int8 or fp8) using the Triton attention
backend.
Run: pytest tests/models/quantization/test_per_token_kv_cache.py -v -s
"""
import pytest
from vllm.platforms import current_platform
from ..utils import check_logprobs_close
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(),
    reason="Per-token-head KV cache requires CUDA or ROCm GPU.",
)
@pytest.mark.parametrize(
    "base_model,test_model",
    [
        (
            "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.2-1B-Instruct",
        ),
    ],
)
@pytest.mark.parametrize(
    "kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"]
)
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("backend", ["TRITON_ATTN"])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_per_token_head_kv_cache_accuracy(
    vllm_runner,
    example_prompts,
    base_model: str,
    test_model: str,
    kv_cache_dtype: str,
    max_tokens: int,
    enforce_eager: bool,
    backend: str,
    tensor_parallel_size: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that a per-token-head quantized KV cache keeps logprobs close.

    The same prompts are greedily decoded twice: once with
    ``kv_cache_dtype="auto"`` (the baseline) and once with the parametrized
    per-token-head dtype. The quantized run passes
    ``calculate_kv_scales=True`` because no per-token-head calibrated
    checkpoints are available yet. The two logprob sets are then compared
    with ``check_logprobs_close``.
    """
    max_model_len = 1024
    num_logprobs = 8

    def runner_kwargs() -> dict:
        # Built fresh for every runner so the two runs never share a
        # mutable ``attention_config`` dict.
        return {
            "max_model_len": max_model_len,
            "tensor_parallel_size": tensor_parallel_size,
            "enforce_eager": enforce_eager,
            "attention_config": {"backend": backend},
        }

    with monkeypatch.context() as patched_env:
        patched_env.setenv("TOKENIZERS_PARALLELISM", "true")

        # Baseline run: unquantized KV cache.
        with vllm_runner(
            base_model,
            kv_cache_dtype="auto",
            **runner_kwargs(),
        ) as baseline_llm:
            baseline_outputs = baseline_llm.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs
            )

        # Run under test: per-token-head quantized KV cache.
        with vllm_runner(
            test_model,
            kv_cache_dtype=kv_cache_dtype,
            calculate_kv_scales=True,
            **runner_kwargs(),
        ) as quantized_llm:
            test_outputs = quantized_llm.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs
            )

        check_logprobs_close(
            outputs_0_lst=baseline_outputs,
            outputs_1_lst=test_outputs,
            name_0="bf16_kv_cache",
            name_1=f"{kv_cache_dtype}_kv_cache",
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for per-token-head KV cache quantization (INT8 and FP8).
Covers:
- Per-token-head Triton reshape-and-cache kernel
- Round-trip quantize/dequantize accuracy
- process_weights_after_loading early-return path
- End-to-end integration with Triton unified attention kernel
Run: pytest tests/quantization/test_per_token_kv_cache.py -v -s
"""
import random
from dataclasses import dataclass
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.kv_cache_interface import KVQuantMode, is_quantized_kv_cache
# Skip entire module if no CUDA/ROCm GPU available
pytestmark = [
pytest.mark.skipif(
not current_platform.is_cuda_alike(),
reason="Per-token-head KV cache tests require CUDA or ROCm GPU.",
),
]
# ---------------------------------------------------------------------------
# Test parameters
# ---------------------------------------------------------------------------
NUM_TOKENS = [1, 7, 42]
NUM_KV_HEADS = [1, 4, 8]
HEAD_SIZES = [64, 128]
BLOCK_SIZES = [16]
SEEDS = [0]
# Platform-dependent FP8 dtype and range
FP8_DTYPE = current_platform.fp8_dtype()
FP8_MIN, FP8_MAX = get_fp8_min_max()
# ---------------------------------------------------------------------------
# Per-dtype quantization config
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class QuantConfig:
    """Immutable bundle of the settings that describe one KV cache dtype.

    Each test that takes the ``qcfg`` fixture runs once per instance.
    """

    # Storage dtype of the quantized cache tensors (torch.int8 or FP8_DTYPE).
    cache_dtype: torch.dtype
    # User-facing dtype name: "int8_per_token_head" or "fp8_per_token_head".
    kv_cache_dtype_str: str
    # Largest representable quantized value; the reference scale is
    # absmax / quant_max.
    quant_max: float
    # Smallest representable quantized value (lower clamp bound).
    quant_min: float
    # Quantization mode enum forwarded to the attention kernel.
    kv_quant_mode: KVQuantMode
    # INT8 Triton stores truncate; FP8 hardware casts round.
    # NOTE(review): the reference helpers in this file apply ``.round()``
    # before the cast when this flag is set -- confirm the name still
    # matches the kernel's behavior.
    uses_trunc: bool
# INT8 per-token-head configuration: symmetric int8 range [-128, 127].
INT8_CONFIG = QuantConfig(
cache_dtype=torch.int8,
kv_cache_dtype_str="int8_per_token_head",
quant_max=127.0,
quant_min=-128.0,
kv_quant_mode=KVQuantMode.INT8_PER_TOKEN_HEAD,
uses_trunc=True,
)
# FP8 per-token-head configuration: dtype and range come from the current
# platform (FP8_DTYPE, FP8_MIN, FP8_MAX defined above).
FP8_CONFIG = QuantConfig(
cache_dtype=FP8_DTYPE,
kv_cache_dtype_str="fp8_per_token_head",
quant_max=FP8_MAX,
quant_min=FP8_MIN,
kv_quant_mode=KVQuantMode.FP8_PER_TOKEN_HEAD,
uses_trunc=False,
)
# Order must stay in sync with the ``ids`` list of the ``qcfg`` fixture.
QUANT_CONFIGS = [INT8_CONFIG, FP8_CONFIG]
@pytest.fixture(params=QUANT_CONFIGS, ids=["int8", "fp8"])
def qcfg(request) -> QuantConfig:
    """Yield each entry of ``QUANT_CONFIGS`` in turn (ids: int8, fp8)."""
    selected_config = request.param
    return selected_config
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _quantize_per_token_head_ref(
data: torch.Tensor, # [num_tokens, num_heads, head_size]
cfg: QuantConfig,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Reference per-token-head quantization (one scale per token per head).
Returns (quantized, scales) where scales is [num_tokens, num_heads].
"""
absmax = data.float().abs().amax(dim=2) # [num_tokens, num_heads]
scales = (absmax / cfg.quant_max).clamp(min=1e-6)
scaled = data.float() * (1.0 / scales[:, :, None])
if cfg.uses_trunc:
q = scaled.round().clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype)
else:
q = scaled.clamp(cfg.quant_min, cfg.quant_max).to(cfg.cache_dtype)
return q, scales
# ===========================================================================
# 1. is_quantized_kv_cache / get_kv_quant_mode
# ===========================================================================
class TestIsQuantizedKvCache:
    """String-level checks for KV cache dtype classification helpers."""

    def test_fp8_variants(self):
        """Every plain fp8 spelling counts as quantized."""
        for dtype_name in ("fp8", "fp8_e4m3", "fp8_e5m2"):
            assert is_quantized_kv_cache(dtype_name)

    def test_int8_per_token_head(self):
        """The int8 per-token-head dtype counts as quantized."""
        assert is_quantized_kv_cache("int8_per_token_head")

    def test_fp8_per_token_head(self):
        """The fp8 per-token-head dtype counts as quantized."""
        assert is_quantized_kv_cache("fp8_per_token_head")

    def test_auto(self):
        """``auto`` is not a quantized dtype."""
        assert not is_quantized_kv_cache("auto")

    def test_bfloat16(self):
        """``bfloat16`` is not a quantized dtype."""
        assert not is_quantized_kv_cache("bfloat16")

    def test_kv_quant_mode_int8(self):
        """The int8 per-token-head string maps to its enum member."""
        from vllm.v1.kv_cache_interface import get_kv_quant_mode

        resolved = get_kv_quant_mode("int8_per_token_head")
        assert resolved == KVQuantMode.INT8_PER_TOKEN_HEAD

    def test_kv_quant_mode_fp8(self):
        """The fp8 per-token-head string maps to its enum member."""
        from vllm.v1.kv_cache_interface import get_kv_quant_mode

        resolved = get_kv_quant_mode("fp8_per_token_head")
        assert resolved == KVQuantMode.FP8_PER_TOKEN_HEAD
# ===========================================================================
# 2. Triton per-token-head kernel (reshape-and-cache)
# ===========================================================================
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_KV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_reshape_and_cache_per_token_head(
    qcfg: QuantConfig,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    seed: int,
):
    """Exercise triton_reshape_and_cache_flash_per_token_head_quant.

    Random bf16 keys/values are scattered into randomly chosen cache slots.
    For every written slot the test checks that (a) dequantizing the cached
    data recovers the input within tolerance and (b) the stored per-head
    scales match the PyTorch reference.
    """
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    set_random_seed(seed)
    torch.set_default_device("cuda")

    # A few spare blocks so the random slots are spread out.
    num_blocks = (num_tokens + block_size - 1) // block_size + 4
    token_shape = (num_tokens, num_heads, head_size)
    cache_shape = (num_blocks, block_size, num_heads, head_size)
    scale_shape = (num_blocks, block_size, num_heads)

    key = torch.randn(*token_shape, dtype=torch.bfloat16)
    value = torch.randn(*token_shape, dtype=torch.bfloat16)

    key_cache = torch.zeros(*cache_shape, dtype=qcfg.cache_dtype)
    value_cache = torch.zeros(*cache_shape, dtype=qcfg.cache_dtype)
    k_scale_cache = torch.ones(*scale_shape, dtype=torch.float32)
    v_scale_cache = torch.ones(*scale_shape, dtype=torch.float32)

    num_slots = block_size * num_blocks
    chosen_slots = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(chosen_slots, dtype=torch.long)

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    # Only the reference scales are compared; the raw quantized values are
    # discarded. Triton and PyTorch reductions can differ at FP8 rounding
    # boundaries (up to 32 in quantized domain for fp8_e4m3), but the
    # dequantized error is bounded by the scale, so values are compared
    # after dequantization instead.
    _, ref_k_scales = _quantize_per_token_head_ref(key, qcfg)
    _, ref_v_scales = _quantize_per_token_head_ref(value, qcfg)

    dequant_checks = (
        (key, key_cache, k_scale_cache),
        (value, value_cache, v_scale_cache),
    )
    scale_checks = (
        (k_scale_cache, ref_k_scales),
        (v_scale_cache, ref_v_scales),
    )

    for tok, slot in enumerate(slot_mapping.tolist()):
        blk, off = divmod(slot, block_size)

        for source, cache, scale_cache in dequant_checks:
            slot_scales = scale_cache[blk, off]  # [num_heads]
            dequantized = cache[blk, off].float() * slot_scales[:, None]
            torch.testing.assert_close(
                dequantized,
                source[tok].float(),
                atol=0.1,
                rtol=0.1,
            )

        # Per-head scales: [num_heads]
        for scale_cache, ref_scales in scale_checks:
            torch.testing.assert_close(
                scale_cache[blk, off], ref_scales[tok], atol=1e-4, rtol=1e-3
            )
# ===========================================================================
# 3. Per-token-head round-trip accuracy (quantize -> dequantize)
# ===========================================================================
@pytest.mark.parametrize("num_tokens", [1, 16])
@pytest.mark.parametrize("num_heads", [4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_per_token_head_round_trip_accuracy(
    qcfg: QuantConfig,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
):
    """Round-trip check: quantize through the kernel, dequantize, compare.

    Tokens are written to consecutive slots (slot i holds token i). Each
    (token, head) vector read back from the cache and multiplied by its
    stored scale must be close to the original bf16 input.

    INT8: Triton truncates on float->int8 store.
    FP8: hardware cast (clamp then cast).
    """
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    torch.set_default_device("cuda")
    set_random_seed(42)

    num_blocks = (num_tokens + block_size - 1) // block_size + 2
    token_shape = (num_tokens, num_heads, head_size)
    cache_shape = (num_blocks, block_size, num_heads, head_size)
    scale_shape = (num_blocks, block_size, num_heads)

    key = torch.randn(*token_shape, dtype=torch.bfloat16) * 0.5
    value = torch.randn(*token_shape, dtype=torch.bfloat16) * 0.5

    key_cache = torch.zeros(*cache_shape, dtype=qcfg.cache_dtype)
    value_cache = torch.zeros(*cache_shape, dtype=qcfg.cache_dtype)
    k_scale_cache = torch.ones(*scale_shape, dtype=torch.float32)
    v_scale_cache = torch.ones(*scale_shape, dtype=torch.float32)

    slot_mapping = torch.arange(num_tokens, dtype=torch.long)

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    tensors_under_test = (
        (key, key_cache, k_scale_cache),
        (value, value_cache, v_scale_cache),
    )

    for tok in range(num_tokens):
        blk, off = divmod(tok, block_size)
        for source, cache, scale_cache in tensors_under_test:
            for head in range(num_heads):
                original = source[tok, head].float()  # [head_size]
                stored_scale = scale_cache[blk, off, head]
                dequantized = cache[blk, off, head].float() * stored_scale
                # Round-trip: dequantized should be close to original
                torch.testing.assert_close(
                    dequantized,
                    original,
                    atol=0.1,
                    rtol=0.1,
                )
# ===========================================================================
# 4. Negative slot mapping (padding tokens should be skipped)
# ===========================================================================
@torch.inference_mode()
def test_per_token_head_negative_slot_skipped(qcfg: QuantConfig):
    """Tokens with slot_mapping=-1 should leave the cache unchanged.

    Four tokens are submitted with slots ``[0, -1, 1, -1]``. Tokens 0 and 2
    must land in slots 0 and 1 of block 0 for both the key and value cache.
    Every other slot must stay exactly as it was before the call, in the
    data caches and in the per-token-head scale caches.
    """
    from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
        triton_reshape_and_cache_flash_per_token_head_quant,
    )

    torch.set_default_device("cuda")
    # Seed the inputs so the "was written" checks below are deterministic.
    set_random_seed(0)

    num_tokens = 4
    num_heads = 2
    head_size = 64
    block_size = 16
    num_blocks = 2

    key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    value = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16)

    key_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    value_cache = torch.zeros(
        num_blocks, block_size, num_heads, head_size, dtype=qcfg.cache_dtype
    )
    k_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)
    v_scale_cache = torch.ones(num_blocks, block_size, num_heads, dtype=torch.float32)

    slot_mapping = torch.tensor([0, -1, 1, -1], dtype=torch.long)

    # Snapshots taken before the kernel runs.
    key_cache_before = key_cache.clone()
    val_cache_before = value_cache.clone()
    k_scale_before = k_scale_cache.clone()
    v_scale_before = v_scale_cache.clone()

    triton_reshape_and_cache_flash_per_token_head_quant(
        key,
        value,
        key_cache,
        value_cache,
        k_scale_cache,
        v_scale_cache,
        slot_mapping,
    )

    # Slots 0 and 1 should have been written (tokens 0 and 2), for keys
    # and values alike.
    assert not torch.equal(key_cache[0, 0], key_cache_before[0, 0])
    assert not torch.equal(key_cache[0, 1], key_cache_before[0, 1])
    assert not torch.equal(value_cache[0, 0], val_cache_before[0, 0])
    assert not torch.equal(value_cache[0, 1], val_cache_before[0, 1])

    # All other slots should be unchanged
    assert torch.equal(key_cache[0, 2:], key_cache_before[0, 2:])
    assert torch.equal(key_cache[1], key_cache_before[1])
    assert torch.equal(value_cache[0, 2:], val_cache_before[0, 2:])
    assert torch.equal(value_cache[1], val_cache_before[1])

    # Skipped tokens must not leak into the scale caches either.
    assert torch.equal(k_scale_cache[0, 2:], k_scale_before[0, 2:])
    assert torch.equal(k_scale_cache[1], k_scale_before[1])
    assert torch.equal(v_scale_cache[0, 2:], v_scale_before[0, 2:])
    assert torch.equal(v_scale_cache[1], v_scale_before[1])
# ===========================================================================
# 5. process_weights_after_loading -- per-token-head early return
# ===========================================================================
@pytest.mark.parametrize(
    "kv_cache_dtype", ["int8_per_token_head", "fp8_per_token_head"]
)
def test_process_weights_sets_placeholder_scales(kv_cache_dtype: str):
    """Per-token-head should set _k_scale=1.0, _v_scale=1.0
    and delete checkpoint attrs.

    Both the scale tensors (``_k_scale`` / ``_v_scale``, updated in place)
    and their float mirrors (``_k_scale_float`` / ``_v_scale_float``) are
    verified.
    """
    from vllm.model_executor.layers.quantization.kv_cache import (
        BaseKVCacheMethod,
    )

    layer = MagicMock()
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = False
    # Checkpoint-style scale parameters that the method is expected to drop.
    layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
    # Start from 0.0 so a missing update cannot pass by accident.
    layer._k_scale = torch.tensor(0.0)
    layer._v_scale = torch.tensor(0.0)
    layer._k_scale_float = 0.0
    layer._v_scale_float = 0.0

    # Bypass __init__; only quant_config is stubbed on the instance.
    method = BaseKVCacheMethod.__new__(BaseKVCacheMethod)
    method.quant_config = MagicMock()
    method.process_weights_after_loading(layer)

    # Placeholder scales: tensors and float mirrors must both be 1.0.
    assert layer._k_scale.item() == 1.0
    assert layer._v_scale.item() == 1.0
    assert layer._k_scale_float == 1.0
    assert layer._v_scale_float == 1.0

    # Checkpoint attributes must have been deleted from the layer.
    assert not hasattr(layer, "k_scale")
    assert not hasattr(layer, "v_scale")
    assert not hasattr(layer, "q_scale")
    assert not hasattr(layer, "prob_scale")
# ===========================================================================
# 6. Triton unified_attention -- per-token-head scale cache (INT8 and FP8)
# ===========================================================================
@pytest.mark.parametrize(
    "seq_lens",
    [
        [(1, 128)],
        [(1, 64), (1, 32)],
    ],
)
@pytest.mark.parametrize("num_heads", [(4, 4)])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("block_size", [16])
@torch.inference_mode()
def test_triton_unified_attention_per_token_head_scale(
    qcfg: QuantConfig,
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    block_size: int,
):
    """End-to-end: quantized KV with per-token-head scale caches.

    The kernel is run twice on the same query and block table: once on the
    quantized cache plus per-(block, slot, head) scale caches, and once on
    the dequantized cache cast to bf16. Both outputs must agree within
    tolerance.
    """
    from vllm.utils.math_utils import next_power_of_2
    from vllm.v1.attention.ops.triton_unified_attention import unified_attention

    torch.set_default_device("cuda")
    set_random_seed(0)

    num_seqs = len(seq_lens)
    query_lens = [q_len for q_len, _ in seq_lens]
    kv_lens = [kv_len for _, kv_len in seq_lens]
    num_query_heads, num_kv_heads = num_heads
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5
    num_blocks = 2048

    # Random inputs; creation order is kept so the RNG stream is unchanged.
    query = torch.randn(
        sum(query_lens), num_query_heads, head_size, dtype=torch.bfloat16
    )
    key_cache_bf16 = torch.randn(
        num_blocks, block_size, num_kv_heads, head_size, dtype=torch.bfloat16
    )
    value_cache_bf16 = torch.randn_like(key_cache_bf16)

    def _quantize(cache_bf16: torch.Tensor):
        """Return (quantized, scales, dequantized) for one cache tensor."""
        cache_f32 = cache_bf16.float()
        # Per-token-head quantization: one scale per (block, slot, head)
        absmax = cache_f32.abs().amax(dim=-1)  # [..., num_kv_heads]
        scales = (absmax / qcfg.quant_max).clamp(min=1e-6).to(torch.float32)
        rescaled = cache_f32 / scales[:, :, :, None]
        if qcfg.uses_trunc:
            rescaled = rescaled.round()
        quantized = rescaled.clamp(qcfg.quant_min, qcfg.quant_max).to(
            qcfg.cache_dtype
        )
        # Dequantized reference
        dequantized = quantized.float() * scales[:, :, :, None]
        return quantized, scales, dequantized

    key_cache_q, k_scale_cache, key_cache_deq = _quantize(key_cache_bf16)
    value_cache_q, v_scale_cache, value_cache_deq = _quantize(value_cache_bf16)

    cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    kv_lens_t = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    # Scratch buffers for the segmented-softmax path; with a threshold of 0
    # their leading dimension is empty.
    head_size_padded = next_power_of_2(head_size)
    seq_threshold_3D = 0
    num_par_softmax_segments = 16
    segm_shape = (seq_threshold_3D, num_query_heads, num_par_softmax_segments)
    softmax_segm_output = torch.empty(
        (*segm_shape, head_size_padded),
        dtype=torch.float32,
    )
    softmax_segm_max = torch.empty(segm_shape, dtype=torch.float32)
    softmax_segm_expsum = torch.empty(segm_shape, dtype=torch.float32)

    # Arguments common to the quantized and the reference invocation.
    shared_kwargs = dict(
        q=query,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_t,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=(-1, -1),
        block_table=block_tables,
        softcap=0,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        seq_threshold_3D=seq_threshold_3D,
        num_par_softmax_segments=num_par_softmax_segments,
        softmax_segm_output=softmax_segm_output,
        softmax_segm_max=softmax_segm_max,
        softmax_segm_expsum=softmax_segm_expsum,
    )

    # Path under test: quantized cache + per-token-head scale caches.
    output_q = torch.empty_like(query)
    unified_attention(
        k=key_cache_q,
        v=value_cache_q,
        out=output_q,
        kv_quant_mode=qcfg.kv_quant_mode,
        k_scale_cache=k_scale_cache,
        v_scale_cache=v_scale_cache,
        **shared_kwargs,
    )

    # Reference path: dequantized cache, no quantization arguments.
    output_ref = torch.empty_like(query)
    unified_attention(
        k=key_cache_deq.to(torch.bfloat16),
        v=value_cache_deq.to(torch.bfloat16),
        out=output_ref,
        **shared_kwargs,
    )

    torch.testing.assert_close(output_q, output_ref, atol=5e-2, rtol=5e-2)
...@@ -8,7 +8,10 @@ from pydantic import Field, SkipValidation, field_validator, model_validator ...@@ -8,7 +8,10 @@ from pydantic import Field, SkipValidation, field_validator, model_validator
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.utils.torch_utils import (
is_quantized_kv_cache,
kv_cache_uses_per_token_head_scales,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -21,6 +24,8 @@ CacheDType = Literal[ ...@@ -21,6 +24,8 @@ CacheDType = Literal[
"fp8_e5m2", "fp8_e5m2",
"fp8_inc", "fp8_inc",
"fp8_ds_mla", "fp8_ds_mla",
"int8_per_token_head",
"fp8_per_token_head",
] ]
MambaDType = Literal["auto", "float32", "float16"] MambaDType = Literal["auto", "float32", "float16"]
MambaCacheMode = Literal["all", "align", "none"] MambaCacheMode = Literal["all", "align", "none"]
...@@ -237,12 +242,20 @@ class CacheConfig: ...@@ -237,12 +242,20 @@ class CacheConfig:
@field_validator("cache_dtype", mode="after") @field_validator("cache_dtype", mode="after")
@classmethod @classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
if is_quantized_kv_cache(cache_dtype): if kv_cache_uses_per_token_head_scales(cache_dtype):
logger.info( logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU " "Using %s data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Dynamic per-token-head scales will be computed at runtime.",
str(cache_dtype),
)
elif is_quantized_kv_cache(cache_dtype):
logger.info(
"Using %s data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. " "memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper " "Meanwhile, it may cause accuracy drop without a proper "
"scaling factor." "scaling factor",
str(cache_dtype),
) )
return cache_dtype return cache_dtype
......
...@@ -38,6 +38,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -38,6 +38,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec, FullAttentionSpec,
KVCacheSpec, KVCacheSpec,
SlidingWindowSpec, SlidingWindowSpec,
get_kv_quant_mode,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -381,8 +382,10 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -381,8 +382,10 @@ class Attention(nn.Module, AttentionLayerBase):
# for attn backends supporting query quantization # for attn backends supporting query quantization
self.query_quant = None self.query_quant = None
if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith( if (
"fp8" self.impl.supports_quant_query_input
and self.kv_cache_dtype.startswith("fp8")
and not self.kv_cache_dtype.endswith("per_token_head")
): ):
is_per_head = ( is_per_head = (
hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
...@@ -539,6 +542,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -539,6 +542,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size = vllm_config.cache_config.block_size block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention. # Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER assert self.attn_type == AttentionType.DECODER
quant_mode = get_kv_quant_mode(self.kv_cache_dtype)
if self.sliding_window is not None: if self.sliding_window is not None:
assert not vllm_config.model_config.use_mla, ( assert not vllm_config.model_config.use_mla, (
"MLA is not supported for slidingwindow" "MLA is not supported for slidingwindow"
...@@ -548,6 +552,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -548,6 +552,7 @@ class Attention(nn.Module, AttentionLayerBase):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
dtype=self.kv_cache_torch_dtype, dtype=self.kv_cache_torch_dtype,
kv_quant_mode=quant_mode,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
) )
else: else:
...@@ -557,6 +562,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -557,6 +562,7 @@ class Attention(nn.Module, AttentionLayerBase):
head_size=self.head_size, head_size=self.head_size,
head_size_v=self.head_size_v, head_size_v=self.head_size_v,
dtype=self.kv_cache_torch_dtype, dtype=self.kv_cache_torch_dtype,
kv_quant_mode=quant_mode,
) )
......
...@@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -23,6 +23,7 @@ from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
ChunkedLocalAttentionSpec, ChunkedLocalAttentionSpec,
KVCacheSpec, KVCacheSpec,
get_kv_quant_mode,
) )
...@@ -123,5 +124,6 @@ class ChunkedLocalAttention(Attention): ...@@ -123,5 +124,6 @@ class ChunkedLocalAttention(Attention):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
dtype=self.kv_cache_torch_dtype, dtype=self.kv_cache_torch_dtype,
kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
attention_chunk_size=self.attention_chunk_size, attention_chunk_size=self.attention_chunk_size,
) )
...@@ -18,7 +18,11 @@ from vllm.v1.attention.backend import ( ...@@ -18,7 +18,11 @@ from vllm.v1.attention.backend import (
subclass_attention_backend_with_overrides, subclass_attention_backend_with_overrides,
) )
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec from vllm.v1.kv_cache_interface import (
CrossAttentionSpec,
KVCacheSpec,
get_kv_quant_mode,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -220,4 +224,5 @@ class CrossAttention(Attention): ...@@ -220,4 +224,5 @@ class CrossAttention(Attention):
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
dtype=self.kv_cache_torch_dtype, dtype=self.kv_cache_torch_dtype,
kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
) )
...@@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import ( ...@@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
KVCacheSpec, KVCacheSpec,
SinkFullAttentionSpec, SinkFullAttentionSpec,
get_kv_quant_mode,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -217,6 +218,7 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -217,6 +218,7 @@ class StaticSinkAttention(Attention, CustomOp):
head_size_v=self.head_size_v, head_size_v=self.head_size_v,
sink_len=self.sink_len, sink_len=self.sink_len,
dtype=self.kv_cache_torch_dtype, dtype=self.kv_cache_torch_dtype,
kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
) )
......
...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -10,6 +10,7 @@ from vllm.model_executor.layers.quantization.base_config import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.kv_cache_interface import kv_cache_uses_per_token_head_scales
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -53,6 +54,20 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -53,6 +54,20 @@ class BaseKVCacheMethod(QuantizeMethodBase):
assert not hasattr(layer, "prob_scale") assert not hasattr(layer, "prob_scale")
return return
# Per-token-head quantized KV cache: scales are computed dynamically
# per (token, head) in the kernel at cache-write time. Checkpoint
# scales are never used regardless of calculate_kv_scales.
if kv_cache_uses_per_token_head_scales(layer.kv_cache_dtype):
layer._k_scale.copy_(1.0)
layer._v_scale.copy_(1.0)
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
del layer.k_scale
del layer.v_scale
del layer.q_scale
del layer.prob_scale
return
# If the kv-cache is not quantized, we enforce the k/v_scale to be 1.0 # If the kv-cache is not quantized, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint. # regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to # No need to process kv scales after loading if we are going to
......
...@@ -505,6 +505,7 @@ class Platform: ...@@ -505,6 +505,7 @@ class Platform:
FullAttentionSpec, FullAttentionSpec,
MambaSpec, MambaSpec,
MLAAttentionSpec, MLAAttentionSpec,
get_kv_quant_mode,
) )
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -516,6 +517,8 @@ class Platform: ...@@ -516,6 +517,8 @@ class Platform:
else: else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
kv_quant_mode = get_kv_quant_mode(cache_config.cache_dtype)
# Compute attention page size for 1 token # Compute attention page size for 1 token
if model_config.use_mla: if model_config.use_mla:
attn_page_size_1_token = MLAAttentionSpec( attn_page_size_1_token = MLAAttentionSpec(
...@@ -523,6 +526,7 @@ class Platform: ...@@ -523,6 +526,7 @@ class Platform:
num_kv_heads=model_config.get_num_kv_heads(parallel_config), num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(), head_size=model_config.get_head_size(),
dtype=kv_cache_dtype, dtype=kv_cache_dtype,
kv_quant_mode=kv_quant_mode,
).page_size_bytes ).page_size_bytes
else: else:
attn_page_size_1_token = FullAttentionSpec( attn_page_size_1_token = FullAttentionSpec(
...@@ -530,6 +534,7 @@ class Platform: ...@@ -530,6 +534,7 @@ class Platform:
num_kv_heads=model_config.get_num_kv_heads(parallel_config), num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(), head_size=model_config.get_head_size(),
dtype=kv_cache_dtype, dtype=kv_cache_dtype,
kv_quant_mode=kv_quant_mode,
).page_size_bytes ).page_size_bytes
# Compute mamba page size # Compute mamba page size
......
...@@ -37,6 +37,8 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -37,6 +37,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e4m3": torch.uint8, "fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8, "fp8_e5m2": torch.uint8,
"int8": torch.int8, "int8": torch.int8,
"int8_per_token_head": torch.int8,
"fp8_per_token_head": torch.uint8,
"fp8_inc": torch.float8_e4m3fn, "fp8_inc": torch.float8_e4m3fn,
"fp8_ds_mla": torch.uint8, "fp8_ds_mla": torch.uint8,
} }
...@@ -62,7 +64,12 @@ T = TypeVar("T") ...@@ -62,7 +64,12 @@ T = TypeVar("T")
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8") return kv_cache_dtype.startswith("fp8") or kv_cache_dtype.endswith("per_token_head")
def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
    """Tell whether *kv_cache_dtype* names a per-token-head quantized cache.

    Such dtype strings are recognised purely by their ``per_token_head``
    suffix (e.g. ``int8_per_token_head``, ``fp8_per_token_head``).
    """
    per_token_head_suffix = "per_token_head"
    return kv_cache_dtype.endswith(per_token_head_suffix)
def is_strictly_contiguous(t: torch.Tensor) -> bool: def is_strictly_contiguous(t: torch.Tensor) -> bool:
......
...@@ -17,7 +17,9 @@ if TYPE_CHECKING: ...@@ -17,7 +17,9 @@ if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType from vllm.v1.attention.backends.utils import KVCacheLayoutType
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, KVQuantMode
from vllm.v1.kv_cache_interface import get_kv_quant_mode
class AttentionType(str, Enum): class AttentionType(str, Enum):
...@@ -740,6 +742,13 @@ class AttentionImplBase(ABC, Generic[T]): ...@@ -740,6 +742,13 @@ class AttentionImplBase(ABC, Generic[T]):
class AttentionImpl(AttentionImplBase[T], Generic[T]): class AttentionImpl(AttentionImplBase[T], Generic[T]):
"""Standard attention implementation with forward method.""" """Standard attention implementation with forward method."""
kv_cache_dtype: str
@property
def kv_quant_mode(self) -> "KVQuantMode":
    """KV cache quantization mode derived from ``self.kv_cache_dtype``.

    The mode is recomputed from the dtype string on every access.
    """
    cache_dtype_str = self.kv_cache_dtype
    return get_kv_quant_mode(cache_dtype_str)
@abstractmethod @abstractmethod
def __init__( def __init__(
self, self,
......
...@@ -33,9 +33,14 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout ...@@ -33,9 +33,14 @@ from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash, triton_reshape_and_cache_flash,
triton_reshape_and_cache_flash_per_token_head_quant,
) )
from vllm.v1.attention.ops.triton_unified_attention import unified_attention from vllm.v1.attention.ops.triton_unified_attention import unified_attention
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import (
AttentionSpec,
get_kv_quant_mode,
kv_cache_uses_per_token_head_scales,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -270,6 +275,8 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -270,6 +275,8 @@ class TritonAttentionBackend(AttentionBackend):
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
"fp8_e5m2", "fp8_e5m2",
"int8_per_token_head",
"fp8_per_token_head",
] ]
@staticmethod @staticmethod
...@@ -302,6 +309,18 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -302,6 +309,18 @@ class TritonAttentionBackend(AttentionBackend):
) -> tuple[int, ...]: ) -> tuple[int, ...]:
if block_size % 16 != 0: if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.") raise ValueError("Block size must be a multiple of 16.")
if kv_cache_uses_per_token_head_scales(cache_dtype_str):
# Pad head_size by sizeof(float32)/sizeof(cache_dtype) so
# the per-head scale fits inline. The backend extracts
# data[:head_size] and scale[head_size:] via typed views.
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
get_dtype_size,
)
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype_str]
scale_pad = get_dtype_size(torch.float32) // get_dtype_size(cache_dtype)
return (num_blocks, 2, block_size, num_kv_heads, head_size + scale_pad)
return (num_blocks, 2, block_size, num_kv_heads, head_size) return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod @staticmethod
...@@ -365,6 +384,62 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -365,6 +384,62 @@ class TritonAttentionBackend(AttentionBackend):
class TritonAttentionImpl(AttentionImpl): class TritonAttentionImpl(AttentionImpl):
# Per-token-head quant: scale views carved from inline head padding.
_k_scale_cache: torch.Tensor | None = None
_v_scale_cache: torch.Tensor | None = None
def _ensure_scale_caches(self, kv_cache: torch.Tensor) -> None:
"""Extract per-head scale views from the padded head dimension.
The KV cache shape is ``(num_blocks, 2, block_size, nkv, hs+pad)``
where ``pad = sizeof(float32) / sizeof(cache_dtype)``. The last
``pad`` elements of each head hold one float32 scale. We create
strided float32 views over those bytes.
Scale shape: ``(num_blocks, block_size, num_kv_heads)``
"""
if self._k_scale_cache is not None:
return
from vllm.utils.torch_utils import get_dtype_size
num_blocks, _, block_size, nkv, padded_hs = kv_cache.shape
dtype_sz = kv_cache.element_size()
scale_pad = get_dtype_size(torch.float32) // dtype_sz # e.g. 4
hs = padded_hs - scale_pad
raw = kv_cache.untyped_storage()
base_f32 = torch.tensor([], dtype=torch.float32, device=kv_cache.device).set_(
raw
)
# In the raw bytes, each (block, kv_half, slot, head) occupies
# padded_hs * dtype_sz bytes. The scale float32 sits at byte
# offset hs * dtype_sz within that region.
kv_half_bytes = block_size * nkv * padded_hs * dtype_sz
full_block_f32 = 2 * kv_half_bytes // 4 # stride between blocks
slot_f32 = nkv * padded_hs * dtype_sz // 4 # stride between slots
head_f32 = padded_hs * dtype_sz // 4 # stride between heads
scale_off_f32 = hs * dtype_sz // 4 # offset to scale within head
# K scales: kv_half=0
self._k_scale_cache = torch.as_strided(
base_f32,
size=(num_blocks, block_size, nkv),
stride=(full_block_f32, slot_f32, head_f32),
storage_offset=scale_off_f32,
)
self._k_scale_cache.fill_(1.0)
# V scales: kv_half=1, offset by kv_half_bytes
v_base_f32 = kv_half_bytes // 4
self._v_scale_cache = torch.as_strided(
base_f32,
size=(num_blocks, block_size, nkv),
stride=(full_block_f32, slot_f32, head_f32),
storage_offset=v_base_f32 + scale_off_f32,
)
self._v_scale_cache.fill_(1.0)
def fused_output_quant_supported(self, quant_key: QuantKey): def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym return quant_key == kFp8StaticTensorSym
...@@ -418,6 +493,9 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -418,6 +493,9 @@ class TritonAttentionImpl(AttentionImpl):
self.use_alibi_sqrt = use_alibi_sqrt self.use_alibi_sqrt = use_alibi_sqrt
self.supports_quant_query_input = current_platform.is_cuda() self.supports_quant_query_input = current_platform.is_cuda()
self._kv_quant_mode = get_kv_quant_mode(kv_cache_dtype)
self._is_per_token_head_quant = self._kv_quant_mode.is_per_token_head
def forward( def forward(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -480,15 +558,35 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -480,15 +558,35 @@ class TritonAttentionImpl(AttentionImpl):
layer, layer,
) )
# For decoder and cross-attention, use KV cache as before # Per-token-head quantized KV cache: use separate scale caches.
key_cache, value_cache = kv_cache.unbind(1) if self._is_per_token_head_quant:
if is_quantized_kv_cache(self.kv_cache_dtype): self._ensure_scale_caches(kv_cache)
if key_cache.dtype != self.fp8_dtype: key_cache, value_cache = kv_cache.unbind(1)
if key_cache.dtype == torch.uint8:
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, ( k_descale = None
"A non 1.0 q_scale is not currently supported." v_descale = None
k_scale_cache = self._k_scale_cache
v_scale_cache = self._v_scale_cache
# FP8 per-tensor / auto path (original flow).
else:
key_cache, value_cache = kv_cache.unbind(1)
if is_quantized_kv_cache(self.kv_cache_dtype):
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
descale_shape = (
attn_metadata.query_start_loc.shape[0] - 1,
key_cache.shape[2],
) )
k_descale = layer._k_scale.expand(descale_shape)
v_descale = layer._v_scale.expand(descale_shape)
k_scale_cache = None
v_scale_cache = None
cu_seqlens_q = attn_metadata.query_start_loc cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens seqused_k = attn_metadata.seq_lens
...@@ -502,7 +600,6 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -502,7 +600,6 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_max = attn_metadata.softmax_segm_max softmax_segm_max = attn_metadata.softmax_segm_max
softmax_segm_expsum = attn_metadata.softmax_segm_expsum softmax_segm_expsum = attn_metadata.softmax_segm_expsum
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
unified_attention( unified_attention(
...@@ -522,8 +619,8 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -522,8 +619,8 @@ class TritonAttentionImpl(AttentionImpl):
block_table=block_table, block_table=block_table,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
q_descale=None, # Not supported q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape), k_descale=k_descale,
v_descale=layer._v_scale.expand(descale_shape), v_descale=v_descale,
seq_threshold_3D=seq_threshold_3D, seq_threshold_3D=seq_threshold_3D,
num_par_softmax_segments=num_par_softmax_segments, num_par_softmax_segments=num_par_softmax_segments,
softmax_segm_output=softmax_segm_output, softmax_segm_output=softmax_segm_output,
...@@ -532,6 +629,9 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -532,6 +629,9 @@ class TritonAttentionImpl(AttentionImpl):
sinks=self.sinks, sinks=self.sinks,
output_scale=output_scale, output_scale=output_scale,
mm_prefix_range=mm_prefix_range_tensor, mm_prefix_range=mm_prefix_range_tensor,
kv_quant_mode=self._kv_quant_mode,
k_scale_cache=k_scale_cache,
v_scale_cache=v_scale_cache,
) )
return output return output
...@@ -555,10 +655,10 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -555,10 +655,10 @@ class TritonAttentionImpl(AttentionImpl):
attn_metadata: Encoder attention metadata attn_metadata: Encoder attention metadata
layer: The attention layer layer: The attention layer
""" """
# For encoder attention, process FP8 quantization if needed # Quantized KV cache is not supported for encoder attention.
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"quantization is not supported for encoder attention" "quantized KV cache is not supported for encoder attention"
) )
# Use encoder-specific metadata for sequence information # Use encoder-specific metadata for sequence information
...@@ -594,16 +694,28 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -594,16 +694,28 @@ class TritonAttentionImpl(AttentionImpl):
# For encoder attention, # For encoder attention,
# we use direct Q, K, V tensors without caching # we use direct Q, K, V tensors without caching
return return
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1)
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
if self._is_per_token_head_quant:
self._ensure_scale_caches(kv_cache)
key_cache, value_cache = kv_cache.unbind(1)
if key_cache.dtype == torch.uint8:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
triton_reshape_and_cache_flash_per_token_head_quant(
key,
value,
key_cache,
value_cache,
self._k_scale_cache,
self._v_scale_cache,
slot_mapping,
)
return
# For decoder and cross-attention, use KV cache as before.
key_cache, value_cache = kv_cache.unbind(1)
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash( triton_reshape_and_cache_flash(
key, key,
value, value,
...@@ -616,6 +728,8 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -616,6 +728,8 @@ class TritonAttentionImpl(AttentionImpl):
) )
def fused_rope_kvcache_supported(self): def fused_rope_kvcache_supported(self):
if self._is_per_token_head_quant:
return False
return rocm_aiter_ops.is_enabled() return rocm_aiter_ops.is_enabled()
def do_rope_and_kv_cache_update( def do_rope_and_kv_cache_update(
......
...@@ -3,10 +3,16 @@ ...@@ -3,10 +3,16 @@
import torch import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FP8_DTYPE,
get_fp8_min_max,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
FP8_MIN, FP8_MAX = get_fp8_min_max()
@triton.jit @triton.jit
def reshape_and_cache_kernel_flash( def reshape_and_cache_kernel_flash(
...@@ -118,6 +124,198 @@ def reshape_and_cache_kernel_flash( ...@@ -118,6 +124,198 @@ def reshape_and_cache_kernel_flash(
return return
# ---------------------------------------------------------------------------
# Per-token-head dynamic quantization kernel
# Grid: (num_tokens, NUM_KV_HEADS)
# Each program handles one (token, head) pair:
# 1. Loads K (or V) for that single head
# 2. Computes absmax across head_size → scale = max(absmax / QUANT_MAX, 1e-6)
# 3. Quantizes and stores the data + per-head scale
#
# Parametrised by QUANT_MAX / QUANT_MIN so the same code path works
# for int8 (+127 / -128), fp8_e4m3 (±448), and other formats.
# ---------------------------------------------------------------------------
@triton.jit
def _reshape_cache_per_token_head(
    key_ptr,  # [num_tokens, num_kv_heads, head_size]
    value_ptr,  # [num_tokens, num_kv_heads, head_size_v]
    key_cache_ptr,  # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_kv_heads, head_size_v]
    k_scale_cache_ptr,  # [num_blocks, block_size, num_kv_heads] float32
    v_scale_cache_ptr,  # [num_blocks, block_size, num_kv_heads] float32
    slot_mapping_ptr,  # [num_tokens]
    stride_key_tok: tl.int64,
    stride_key_head: tl.int64,
    stride_val_tok: tl.int64,
    stride_val_head: tl.int64,
    stride_kc_blk: tl.int64,  # key_cache stride over blocks
    stride_kc_slot: tl.int64,  # key_cache stride over slots
    stride_kc_head: tl.int64,  # key_cache stride over heads
    stride_vc_blk: tl.int64,
    stride_vc_slot: tl.int64,
    stride_vc_head: tl.int64,
    stride_ks_blk: tl.int64,  # k_scale_cache stride[0] (blocks)
    stride_ks_slot: tl.int64,  # k_scale_cache stride[1] (slots)
    stride_ks_head: tl.int64,  # k_scale_cache stride[2] (heads)
    stride_vs_blk: tl.int64,  # v_scale_cache stride[0] (blocks)
    stride_vs_slot: tl.int64,  # v_scale_cache stride[1] (slots)
    stride_vs_head: tl.int64,  # v_scale_cache stride[2] (heads)
    block_size: tl.constexpr,
    head_size: tl.constexpr,
    head_size_v: tl.constexpr,
    HEAD_SIZE_PADDED: tl.constexpr,  # next_power_of_2(max(head_size, head_size_v))
    QUANT_MAX: tl.constexpr = 127.0,
    QUANT_MIN: tl.constexpr = -128.0,
):
    """Quantize one (token, head) of K and V and write it to the paged cache.

    Each program instance handles the pair ``(program_id(0), program_id(1))``
    = (token, kv head). For K and then V it loads the head vector as
    float32, derives ``scale = max(absmax / QUANT_MAX, 1e-6)``, stores the
    scale in the matching scale cache, and stores the vector multiplied by
    ``1 / scale`` and clamped to ``[QUANT_MIN, QUANT_MAX]`` in the data
    cache.

    The head dimension of the inputs and of both data caches is addressed
    as ``base + dim_offs`` with no stride factor, i.e. it is treated as
    unit-stride.

    Tokens whose slot mapping entry is negative are skipped entirely.
    """
    tok = tl.program_id(0)
    head = tl.program_id(1)

    slot = tl.load(slot_mapping_ptr + tok).to(tl.int64)
    # Negative slot: nothing is written for this token
    # (presumably a padding marker -- confirm against the caller).
    if slot < 0:
        return

    # Split the flat slot index into (physical block, slot within block).
    blk = slot // block_size
    slot_in_blk = slot % block_size

    # One lane per head element; lanes beyond the real head size are masked.
    dim_offs = tl.arange(0, HEAD_SIZE_PADDED)

    # ---- Key: load one head → absmax → quantize → store -------------------
    k_mask = dim_offs < head_size
    k_h = tl.load(
        key_ptr + tok * stride_key_tok + head * stride_key_head + dim_offs,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    # Masked lanes load 0.0 and therefore do not affect the absmax.
    # The 1e-6 floor keeps the reciprocal below finite for all-zero heads.
    k_scale = tl.maximum(tl.max(tl.abs(k_h)) / QUANT_MAX, 1e-6)
    tl.store(
        k_scale_cache_ptr
        + blk * stride_ks_blk
        + slot_in_blk * stride_ks_slot
        + head * stride_ks_head,
        k_scale,
    )

    k_q = tl.clamp(k_h * (1.0 / k_scale), QUANT_MIN, QUANT_MAX)
    # k_q is still float32 here; the conversion to the cache element type
    # happens implicitly in tl.store.
    # NOTE(review): for an int8 cache, confirm whether that implicit cast
    # rounds to nearest or truncates toward zero.
    tl.store(
        key_cache_ptr
        + blk * stride_kc_blk
        + slot_in_blk * stride_kc_slot
        + head * stride_kc_head
        + dim_offs,
        k_q,
        mask=k_mask,
    )

    # ---- Value: same per-head approach ------------------------------------
    v_mask = dim_offs < head_size_v
    v_h = tl.load(
        value_ptr + tok * stride_val_tok + head * stride_val_head + dim_offs,
        mask=v_mask,
        other=0.0,
    ).to(tl.float32)

    # Same absmax scale with a 1e-6 floor as for the key.
    v_scale = tl.maximum(tl.max(tl.abs(v_h)) / QUANT_MAX, 1e-6)
    tl.store(
        v_scale_cache_ptr
        + blk * stride_vs_blk
        + slot_in_blk * stride_vs_slot
        + head * stride_vs_head,
        v_scale,
    )

    v_q = tl.clamp(v_h * (1.0 / v_scale), QUANT_MIN, QUANT_MAX)
    # Implicit cast to the cache element type on store, as for the key.
    tl.store(
        value_cache_ptr
        + blk * stride_vc_blk
        + slot_in_blk * stride_vc_slot
        + head * stride_vc_head
        + dim_offs,
        v_q,
        mask=v_mask,
    )
# Mapping from cache torch dtype to (QUANT_MAX, QUANT_MIN) for the
# per-token-head quantization kernel. Note the tuple order: max first,
# then min. The launcher looks the cache tensor's dtype up here and
# raises ValueError for dtypes that are not listed.
_PER_TOKEN_HEAD_QUANT_PARAMS: dict[torch.dtype, tuple[float, float]] = {
    # int8 uses the full asymmetric range for clamping.
    torch.int8: (127.0, -128.0),
    # FP8 bounds come from get_fp8_min_max() at import time.
    FP8_DTYPE: (FP8_MAX, FP8_MIN),
}
def triton_reshape_and_cache_flash_per_token_head_quant(
    key: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    value: torch.Tensor,  # [num_tokens, num_kv_heads, head_size_v]
    key_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size]
    value_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads, head_size_v]
    k_scale_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads] float32
    v_scale_cache: torch.Tensor,  # [num_blocks, block_size, num_kv_heads] float32
    slot_mapping: torch.Tensor,  # [num_tokens]
):
    """Launch the per-(token, head) quantizing cache-write kernel.

    One kernel program is launched per (token, kv head). Each program
    writes the quantized K/V head into ``key_cache`` / ``value_cache``
    and its float32 scale into ``k_scale_cache`` / ``v_scale_cache``.

    The clamp range (QUANT_MAX, QUANT_MIN) is looked up from
    ``key_cache.dtype`` so int8 and fp8 caches share this entry point.

    Raises:
        ValueError: if ``key_cache.dtype`` has no registered
            quantization range.
    """
    cache_dtype = key_cache.dtype
    if cache_dtype not in _PER_TOKEN_HEAD_QUANT_PARAMS:
        raise ValueError(
            f"Per-token-head quantization not supported for cache dtype "
            f"{cache_dtype}. Supported: {list(_PER_TOKEN_HEAD_QUANT_PARAMS)}"
        )
    quant_max, quant_min = _PER_TOKEN_HEAD_QUANT_PARAMS[cache_dtype]

    num_tokens, num_kv_heads, head_size = key.shape
    head_size_v = value.shape[2]
    # A single lane range must cover both the K and the V head.
    head_size_padded = triton.next_power_of_2(max(head_size, head_size_v))
    block_size = key_cache.shape[1]

    # Fixed warp count on ROCm/XPU; otherwise one warp per 32 lanes,
    # bounded to [1, 16].
    fixed_warp_platform = current_platform.is_rocm() or current_platform.is_xpu()
    num_warps = 4 if fixed_warp_platform else min(16, max(1, head_size_padded // 32))

    # Leading strides of the inputs (token, head) ...
    key_tok, key_head = (key.stride(i) for i in range(2))
    val_tok, val_head = (value.stride(i) for i in range(2))
    # ... and of the caches / scale caches (block, slot, head).
    kc_blk, kc_slot, kc_head = (key_cache.stride(i) for i in range(3))
    vc_blk, vc_slot, vc_head = (value_cache.stride(i) for i in range(3))
    ks_blk, ks_slot, ks_head = (k_scale_cache.stride(i) for i in range(3))
    vs_blk, vs_slot, vs_head = (v_scale_cache.stride(i) for i in range(3))

    grid = (num_tokens, num_kv_heads)
    _reshape_cache_per_token_head[grid](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        k_scale_cache_ptr=k_scale_cache,
        v_scale_cache_ptr=v_scale_cache,
        slot_mapping_ptr=slot_mapping,
        stride_key_tok=key_tok,
        stride_key_head=key_head,
        stride_val_tok=val_tok,
        stride_val_head=val_head,
        stride_kc_blk=kc_blk,
        stride_kc_slot=kc_slot,
        stride_kc_head=kc_head,
        stride_vc_blk=vc_blk,
        stride_vc_slot=vc_slot,
        stride_vc_head=vc_head,
        stride_ks_blk=ks_blk,
        stride_ks_slot=ks_slot,
        stride_ks_head=ks_head,
        stride_vs_blk=vs_blk,
        stride_vs_slot=vs_slot,
        stride_vs_head=vs_head,
        block_size=block_size,
        head_size=head_size,
        head_size_v=head_size_v,
        HEAD_SIZE_PADDED=head_size_padded,
        QUANT_MAX=quant_max,
        QUANT_MIN=quant_min,
        num_warps=num_warps,
    )
def triton_reshape_and_cache_flash( def triton_reshape_and_cache_flash(
key: torch.Tensor, # [num_tokens, num_heads, head_size] key: torch.Tensor, # [num_tokens, num_heads, head_size]
value: torch.Tensor, # [num_tokens, num_heads, head_size] value: torch.Tensor, # [num_tokens, num_heads, head_size]
...@@ -224,7 +422,6 @@ def triton_reshape_and_cache_flash( ...@@ -224,7 +422,6 @@ def triton_reshape_and_cache_flash(
block_size=block_size, block_size=block_size,
x=x, x=x,
USE_HEAD_MAJOR_LAYOUT=use_head_major_layout, USE_HEAD_MAJOR_LAYOUT=use_head_major_layout,
# FP8 flags
FP8_KV_CACHE=FP8_KV_CACHE, FP8_KV_CACHE=FP8_KV_CACHE,
# autotune parameters # autotune parameters
TILE_SIZE=TILE_SIZE, TILE_SIZE=TILE_SIZE,
......
...@@ -13,6 +13,7 @@ import vllm.envs as envs ...@@ -13,6 +13,7 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.kv_cache_interface import KVQuantMode
logger = init_logger(__name__) logger = init_logger(__name__)
is_batch_invariant = envs.VLLM_BATCH_INVARIANT is_batch_invariant = envs.VLLM_BATCH_INVARIANT
...@@ -32,6 +33,63 @@ def apply_softcap(S, x): ...@@ -32,6 +33,63 @@ def apply_softcap(S, x):
return x * (p1 - p2) / (p1 + p2) return x * (p1 - p2) / (p1 + p2)
@triton.jit
def _prepare_kv_tile(
    data,
    Q,
    tensor_scale,
    scale_cache_ptr,
    physical_block_idx,
    seq_offset,
    kv_head_idx,
    stride_s_blk,
    stride_s_slot,
    stride_s_head,
    tile_mask,
    BLOCK_SIZE: tl.constexpr,
    KV_QUANT_MODE: tl.constexpr,
):
    """Prepare a loaded KV tile for attention computation.

    Casts the raw KV data to Q's dtype and loads per-token-head scales
    when applicable:

    - ``KV_QUANT_MODE == 0``: cast only (no-op for bf16/fp16).
    - ``KV_QUANT_MODE == 1`` (FP8 per-tensor): if Q is itself fp8 the
      data is only cast to Q's dtype and ``tensor_scale`` is NOT applied
      here; otherwise the data is dequantized inline as
      ``float32(data) * load(tensor_scale)`` and cast to Q's dtype.
    - ``KV_QUANT_MODE >= 2`` (per-token-head int8/fp8): cast to Q's
      dtype and return per-head scales separately — the caller applies
      them after the dot product for better numerical efficiency.

    Returns ``(data, token_head_scales)``. *token_head_scales* is only
    meaningful when ``KV_QUANT_MODE >= 2``; callers gate its use on
    the same constexpr so the compiler eliminates dead code.
    """
    # KV_QUANT_MODE values: 0=none, 1=fp8 per-tensor,
    #   2=int8 per-token-head, 3=fp8 per-token-head
    # Placeholder scales (float32) — never read when KV_QUANT_MODE < 2.
    # Derived from tile_mask so the placeholder has the same per-token
    # shape as the real scales loaded below.
    unused_scales = tile_mask.to(tl.float32)

    if KV_QUANT_MODE == 1:  # FP8 per-tensor
        if Q.dtype.is_fp8():
            # fp8 query: keep the data in fp8, no scale applied here.
            return data.to(Q.dtype), unused_scales
        return (data.to(tl.float32) * tl.load(tensor_scale)).to(Q.dtype), unused_scales

    if KV_QUANT_MODE >= 2:  # per-token-head (int8 or fp8)
        # Index into the [block, slot, head] scale cache: the slot is the
        # token position modulo BLOCK_SIZE within the physical block.
        scale_idx = (
            physical_block_idx * stride_s_blk
            + (seq_offset % BLOCK_SIZE) * stride_s_slot
            + kv_head_idx * stride_s_head
        )
        # Masked-out tokens get a neutral scale of 1.0.
        token_head_scales = tl.load(
            scale_cache_ptr + scale_idx, mask=tile_mask, other=1.0
        )
        return data.to(Q.dtype), token_head_scales

    # .to(Q.dtype) is a no-op when data is already Q's type (bf16/fp16),
    # but required so Triton sees consistent return types across branches.
    return data.to(Q.dtype), unused_scales
@triton.jit @triton.jit
def find_seq_idx( def find_seq_idx(
query_start_len_ptr, query_start_len_ptr,
...@@ -105,8 +163,20 @@ def kernel_unified_attention_2d( ...@@ -105,8 +163,20 @@ def kernel_unified_attention_2d(
num_seqs: tl.int32, num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int BLOCK_M: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool USE_FP8: tl.constexpr, # bool
# KV cache quantization: 0=none, 1=fp8 per-tensor, 2=int8 per-token-head, 3=fp8 per-token-head
KV_QUANT_MODE: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min, FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max, FP8_MAX: tl.constexpr = float8_info.max,
# Per-token-head scale caches (KV_QUANT_MODE >= 2)
# Shape: [num_blocks, block_size, num_kv_heads]
k_scale_cache_ptr=None,
v_scale_cache_ptr=None,
stride_ks_blk=0,
stride_ks_slot=0,
stride_ks_head=0,
stride_vs_blk=0,
stride_vs_slot=0,
stride_vs_head=0,
): ):
q_block_global_idx = tl.program_id(0) q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1) kv_head_idx = tl.program_id(1)
...@@ -258,14 +328,21 @@ def kernel_unified_attention_2d( ...@@ -258,14 +328,21 @@ def kernel_unified_attention_2d(
mask=dim_mask[:, None] & tile_mask[None, :], mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0, other=0.0,
) )
K, k_token_head_scales = _prepare_kv_tile(
if K_load.dtype.is_fp8(): K_load,
if Q.dtype.is_fp8(): Q,
K = K_load k_scale,
else: k_scale_cache_ptr,
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) physical_block_idx,
else: seq_offset,
K = K_load kv_head_idx,
stride_ks_blk,
stride_ks_slot,
stride_ks_head,
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
)
# V : (TILE_SIZE, HEAD_SIZE) # V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load( V_load = tl.load(
...@@ -273,14 +350,21 @@ def kernel_unified_attention_2d( ...@@ -273,14 +350,21 @@ def kernel_unified_attention_2d(
mask=dim_mask[None, :] & tile_mask[:, None], mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0, other=0.0,
) )
V, v_token_head_scales = _prepare_kv_tile(
if V_load.dtype.is_fp8(): V_load,
if Q.dtype.is_fp8(): Q,
V = V_load v_scale,
else: v_scale_cache_ptr,
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) physical_block_idx,
else: seq_offset,
V = V_load kv_head_idx,
stride_vs_blk,
stride_vs_slot,
stride_vs_head,
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
)
# Compute attention mask: causal by default (key <= query) # Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None] query_abs_pos = context_len + query_pos[:, None]
...@@ -318,7 +402,12 @@ def kernel_unified_attention_2d( ...@@ -318,7 +402,12 @@ def kernel_unified_attention_2d(
# S : (BLOCK_M, TILE_SIZE) # S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
S += scale * tl.dot(Q, K) # Per-token-head quant: fuse softmax_scale with per-head k_scale
# to avoid a separate BLOCK_M × TILE_SIZE multiply on S.
if KV_QUANT_MODE >= 2:
S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :])
else:
S += scale * tl.dot(Q, K)
if USE_SOFTCAP: if USE_SOFTCAP:
S = apply_softcap(S, softcap) S = apply_softcap(S, softcap)
...@@ -382,7 +471,12 @@ def kernel_unified_attention_2d( ...@@ -382,7 +471,12 @@ def kernel_unified_attention_2d(
) )
# acc : (BLOCK_M, HEAD_SIZE_PADDED) # acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V) # Per-token-head quant: apply v_scale to P instead of V.
if KV_QUANT_MODE >= 2:
P_v = (P * v_token_head_scales[None, :]).to(V.dtype)
acc += tl.dot(P_v, V)
else:
acc += tl.dot(P.to(V.dtype), V)
# epilogue # epilogue
acc = acc / L[:, None] acc = acc / L[:, None]
...@@ -453,6 +547,18 @@ def kernel_unified_attention_3d( ...@@ -453,6 +547,18 @@ def kernel_unified_attention_3d(
USE_MM_PREFIX: tl.constexpr, # bool USE_MM_PREFIX: tl.constexpr, # bool
MAX_MM_RANGES: tl.constexpr, # int MAX_MM_RANGES: tl.constexpr, # int
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
# KV cache quantization: 0=none, 1=fp8 per-tensor, 2=int8 per-token-head, 3=fp8 per-token-head
KV_QUANT_MODE: tl.constexpr = 0,
# Per-token-head scale caches (KV_QUANT_MODE >= 2)
# Shape: [num_blocks, block_size, num_kv_heads]
k_scale_cache_ptr=None,
v_scale_cache_ptr=None,
stride_ks_blk=0,
stride_ks_slot=0,
stride_ks_head=0,
stride_vs_blk=0,
stride_vs_slot=0,
stride_vs_head=0,
): ):
q_block_global_idx = tl.program_id(0) q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1) kv_head_idx = tl.program_id(1)
...@@ -613,14 +719,21 @@ def kernel_unified_attention_3d( ...@@ -613,14 +719,21 @@ def kernel_unified_attention_3d(
mask=dim_mask[:, None] & tile_mask[None, :], mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0, other=0.0,
) )
K, k_token_head_scales = _prepare_kv_tile(
if K_load.dtype.is_fp8(): K_load,
if Q.dtype.is_fp8(): Q,
K = K_load k_scale,
else: k_scale_cache_ptr,
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) physical_block_idx,
else: seq_offset,
K = K_load kv_head_idx,
stride_ks_blk,
stride_ks_slot,
stride_ks_head,
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
)
# V : (TILE_SIZE, HEAD_SIZE) # V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load( V_load = tl.load(
...@@ -628,14 +741,21 @@ def kernel_unified_attention_3d( ...@@ -628,14 +741,21 @@ def kernel_unified_attention_3d(
mask=dim_mask[None, :] & tile_mask[:, None], mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0, other=0.0,
) )
V, v_token_head_scales = _prepare_kv_tile(
if V_load.dtype.is_fp8(): V_load,
if Q.dtype.is_fp8(): Q,
V = V_load v_scale,
else: v_scale_cache_ptr,
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) physical_block_idx,
else: seq_offset,
V = V_load kv_head_idx,
stride_vs_blk,
stride_vs_slot,
stride_vs_head,
tile_mask,
BLOCK_SIZE,
KV_QUANT_MODE,
)
# Compute attention mask: causal by default (key <= query) # Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None] query_abs_pos = context_len + query_pos[:, None]
...@@ -672,7 +792,13 @@ def kernel_unified_attention_3d( ...@@ -672,7 +792,13 @@ def kernel_unified_attention_3d(
# S : (BLOCK_M, TILE_SIZE) # S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
S += scale * tl.dot(Q, K)
# Per-token-head quant: fuse softmax_scale with per-head k_scale
# to avoid a separate BLOCK_M × TILE_SIZE multiply on S.
if KV_QUANT_MODE >= 2:
S += tl.dot(Q, K) * (scale * k_token_head_scales[None, :])
else:
S += scale * tl.dot(Q, K)
if USE_SOFTCAP: if USE_SOFTCAP:
S = apply_softcap(S, softcap) S = apply_softcap(S, softcap)
...@@ -736,7 +862,12 @@ def kernel_unified_attention_3d( ...@@ -736,7 +862,12 @@ def kernel_unified_attention_3d(
) )
# acc : (BLOCK_M, HEAD_SIZE_PADDED) # acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V) # Per-token-head quant: apply v_scale to P instead of V.
if KV_QUANT_MODE >= 2:
P_v = (P * v_token_head_scales[None, :]).to(V.dtype)
acc += tl.dot(P_v, V)
else:
acc += tl.dot(P.to(V.dtype), V)
segm_output_offset = ( segm_output_offset = (
query_offset_0[:, None].to(tl.int64) query_offset_0[:, None].to(tl.int64)
...@@ -911,6 +1042,10 @@ def unified_attention( ...@@ -911,6 +1042,10 @@ def unified_attention(
# Optional tensor for prefix lengths (PrefixLM support) # Optional tensor for prefix lengths (PrefixLM support)
mm_prefix_range=None, mm_prefix_range=None,
use_alibi_sqrt=False, use_alibi_sqrt=False,
# KV cache quantization mode and per-token-head scale caches.
kv_quant_mode: KVQuantMode = KVQuantMode.NONE,
k_scale_cache=None, # [num_blocks, block_size, num_kv_heads] float32
v_scale_cache=None, # [num_blocks, block_size, num_kv_heads] float32
): ):
assert causal, "Only causal attention is supported" assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported" assert q_descale is None, "Q scales not supported"
...@@ -1040,6 +1175,15 @@ def unified_attention( ...@@ -1040,6 +1175,15 @@ def unified_attention(
num_seqs=num_seqs, num_seqs=num_seqs,
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None, USE_FP8=output_scale is not None,
KV_QUANT_MODE=kv_quant_mode,
k_scale_cache_ptr=k_scale_cache,
v_scale_cache_ptr=v_scale_cache,
stride_ks_blk=k_scale_cache.stride(0) if k_scale_cache is not None else 0,
stride_ks_slot=k_scale_cache.stride(1) if k_scale_cache is not None else 0,
stride_ks_head=k_scale_cache.stride(2) if k_scale_cache is not None else 0,
stride_vs_blk=v_scale_cache.stride(0) if v_scale_cache is not None else 0,
stride_vs_slot=v_scale_cache.stride(1) if v_scale_cache is not None else 0,
stride_vs_head=v_scale_cache.stride(2) if v_scale_cache is not None else 0,
) )
else: else:
kernel_unified_attention_3d[ kernel_unified_attention_3d[
...@@ -1092,6 +1236,15 @@ def unified_attention( ...@@ -1092,6 +1236,15 @@ def unified_attention(
num_seqs=num_seqs, num_seqs=num_seqs,
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments, NUM_SEGMENTS_PER_SEQ=num_par_softmax_segments,
KV_QUANT_MODE=kv_quant_mode,
k_scale_cache_ptr=k_scale_cache,
v_scale_cache_ptr=v_scale_cache,
stride_ks_blk=k_scale_cache.stride(0) if k_scale_cache is not None else 0,
stride_ks_slot=k_scale_cache.stride(1) if k_scale_cache is not None else 0,
stride_ks_head=k_scale_cache.stride(2) if k_scale_cache is not None else 0,
stride_vs_blk=v_scale_cache.stride(0) if v_scale_cache is not None else 0,
stride_vs_slot=v_scale_cache.stride(1) if v_scale_cache is not None else 0,
stride_vs_head=v_scale_cache.stride(2) if v_scale_cache is not None else 0,
) )
reduce_segments[(q.shape[0], num_query_heads)]( reduce_segments[(q.shape[0], num_query_heads)](
output_ptr=out, output_ptr=out,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import copy import copy
from dataclasses import dataclass, fields, replace from dataclasses import dataclass, fields, replace
from enum import IntEnum
from math import prod from math import prod
from typing import TYPE_CHECKING
import torch import torch
from typing_extensions import Self from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_dtype_size from vllm.utils.torch_utils import get_dtype_size
logger = init_logger(__name__) logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# KV cache quantization mode
# ---------------------------------------------------------------------------
class KVQuantMode(IntEnum):
    """KV cache quantization mode.

    Used by attention backends and kernels to dispatch quantization logic
    without string matching on ``kv_cache_dtype``.

    The members are plain integers (``IntEnum``), so a mode can be compared
    with, or passed where, an ``int`` is expected.
    """

    NONE = 0
    FP8_PER_TENSOR = 1  # per-tensor scales (current fp8 path)
    INT8_PER_TOKEN_HEAD = 2  # per-token-head dynamic scales for int8
    FP8_PER_TOKEN_HEAD = 3  # per-token-head dynamic scales for fp8

    @property
    def is_per_token_head(self) -> bool:
        """True for any per-token-head quantization mode."""
        # Name the members explicitly instead of comparing against a raw
        # integer threshold, so that a mode added later with a larger value
        # is not silently classified as per-token-head.
        return self in (
            KVQuantMode.INT8_PER_TOKEN_HEAD,
            KVQuantMode.FP8_PER_TOKEN_HEAD,
        )
def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
    """Translate a ``kv_cache_dtype`` string into its :class:`KVQuantMode`.

    Args:
        kv_cache_dtype: The KV cache dtype name, e.g. ``"auto"``,
            ``"fp8_e4m3"`` or ``"int8_per_token_head"``.

    Returns:
        The matching per-token-head mode for the two exact
        ``*_per_token_head`` names, ``FP8_PER_TENSOR`` for any other name
        beginning with ``"fp8"``, and ``NONE`` for everything else.
    """
    # The exact per-token-head names are tested first: "fp8_per_token_head"
    # also begins with "fp8" and would otherwise fall into the per-tensor
    # branch below.
    if kv_cache_dtype == "int8_per_token_head":
        mode = KVQuantMode.INT8_PER_TOKEN_HEAD
    elif kv_cache_dtype == "fp8_per_token_head":
        mode = KVQuantMode.FP8_PER_TOKEN_HEAD
    elif kv_cache_dtype.startswith("fp8"):
        mode = KVQuantMode.FP8_PER_TENSOR
    else:
        mode = KVQuantMode.NONE
    return mode
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True if *kv_cache_dtype* maps to any quantized KV cache mode."""
    resolved_mode = get_kv_quant_mode(kv_cache_dtype)
    return resolved_mode != KVQuantMode.NONE
def kv_cache_uses_per_token_head_scales(kv_cache_dtype: str) -> bool:
    """Return True if *kv_cache_dtype* selects a per-token-head mode.

    The dtype string is resolved with ``get_kv_quant_mode`` and the answer
    is that mode's ``is_per_token_head`` property.
    """
    resolved_mode = get_kv_quant_mode(kv_cache_dtype)
    return resolved_mode.is_per_token_head
@dataclass(frozen=True) @dataclass(frozen=True)
class KVCacheSpec: class KVCacheSpec:
""" """
...@@ -66,11 +115,19 @@ class AttentionSpec(KVCacheSpec): ...@@ -66,11 +115,19 @@ class AttentionSpec(KVCacheSpec):
num_kv_heads: int num_kv_heads: int
head_size: int head_size: int
dtype: torch.dtype dtype: torch.dtype
kv_quant_mode: KVQuantMode = KVQuantMode.NONE
page_size_padded: int | None = None page_size_padded: int | None = None
@property @property
def page_size_bytes(self) -> int: def page_size_bytes(self) -> int:
real_page_size = self.real_page_size_bytes real_page_size = self.real_page_size_bytes
# Per-token-head scales are stored in separate tensors managed
# by the attention backend, but the memory is carved from the
# raw KV cache allocation so it must be budgeted here.
if self.kv_quant_mode.is_per_token_head:
real_page_size += (
2 * self.block_size * self.num_kv_heads * get_dtype_size(torch.float32)
)
if self.page_size_padded is not None: if self.page_size_padded is not None:
assert self.page_size_padded >= real_page_size assert self.page_size_padded >= real_page_size
return self.page_size_padded return self.page_size_padded
...@@ -159,6 +216,7 @@ class FullAttentionSpec(AttentionSpec): ...@@ -159,6 +216,7 @@ class FullAttentionSpec(AttentionSpec):
head_size=specs[0].head_size, head_size=specs[0].head_size,
head_size_v=specs[0].head_size_v, head_size_v=specs[0].head_size_v,
dtype=specs[0].dtype, dtype=specs[0].dtype,
kv_quant_mode=specs[0].kv_quant_mode,
page_size_padded=specs[0].page_size_padded, page_size_padded=specs[0].page_size_padded,
sliding_window=cls.merge_window_sizes(sliding_window), sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
...@@ -220,6 +278,7 @@ class MLAAttentionSpec(FullAttentionSpec): ...@@ -220,6 +278,7 @@ class MLAAttentionSpec(FullAttentionSpec):
num_kv_heads=specs[0].num_kv_heads, num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size, head_size=specs[0].head_size,
dtype=specs[0].dtype, dtype=specs[0].dtype,
kv_quant_mode=specs[0].kv_quant_mode,
page_size_padded=specs[0].page_size_padded, page_size_padded=specs[0].page_size_padded,
cache_dtype_str=cache_dtype_str_set.pop(), cache_dtype_str=cache_dtype_str_set.pop(),
) )
...@@ -352,6 +411,7 @@ class SinkFullAttentionSpec(FullAttentionSpec): ...@@ -352,6 +411,7 @@ class SinkFullAttentionSpec(FullAttentionSpec):
head_size_v=specs[0].head_size_v, head_size_v=specs[0].head_size_v,
sink_len=specs[0].sink_len, sink_len=specs[0].sink_len,
dtype=specs[0].dtype, dtype=specs[0].dtype,
kv_quant_mode=specs[0].kv_quant_mode,
page_size_padded=specs[0].page_size_padded, page_size_padded=specs[0].page_size_padded,
sliding_window=cls.merge_window_sizes(sliding_window), sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment