Unverified Commit 7f829be7 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[CPU] Refactor CPU attention backend (#27954)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent e1710393
......@@ -35,7 +35,7 @@ DEVICE_MLA_BACKENDS = {
DEVICE_REGULAR_ATTN_BACKENDS = {
"cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
"hip": ["ROCM_ATTN"],
"cpu": ["TORCH_SDPA"],
"cpu": ["CPU_ATTN"],
}
DEVICE_MLA_BLOCK_SIZES = {
......@@ -86,7 +86,7 @@ def test_env(
if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size)
assert backend.get_name() == "TORCH_SDPA"
assert backend.get_name() == "CPU_ATTN"
elif device == "hip":
with patch("vllm.platforms.current_platform", RocmPlatform()):
......@@ -224,7 +224,7 @@ def test_fp32_fallback(device: str):
if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "TORCH_SDPA"
assert backend.get_name() == "CPU_ATTN"
elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import math
import pytest
import torch
from vllm.platforms import current_platform
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
from vllm._custom_ops import (
cpu_attention_with_kv_cache,
cpu_attn_get_scheduler_metadata,
cpu_attn_reshape_and_cache,
)
NUM_HEADS = [
(4, 4),
(8, 2),
(9, 3),
]
HEAD_SIZES = [96, 128]
QTYPES = [torch.bfloat16, torch.half, torch.float32]
SLIDING_WINDOWS = [None, 256]
NUM_BLOCKS = [
1024,
]
SEQ_LENS = [ # (q_len, kv_len)
[(1, 213), (1, 1), (1, 312), (1, 7), (1, 7812)], # decode batch
[(2345, 2345), (5, 5), (3, 16), (134, 5131)], # prefill batch
[(992, 2456), (1, 1234), (98, 1145), (1, 4162), (2345, 2345)], # mixed batch
]
# rand number generation takes too much time, cache rand tensors
@functools.lru_cache(maxsize=128, typed=False)
def tensor_cache(
elem_num: int,
dtype: torch.dtype,
) -> torch.Tensor:
tensor = torch.randn(elem_num, dtype=dtype)
return tensor
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(
closest_power_of_2, total_num_heads - closest_power_of_2
)
extra_powers = torch.arange(
start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes.float()
def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: list[int],
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: int | None = None,
soft_cap: float | None = None,
alibi_slopes: torch.Tensor | None = None,
s_aux: torch.Tensor | None = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
dtype = query.dtype
outputs: list[torch.Tensor] = []
start_idx = 0
if alibi_slopes is not None:
alibi_slopes = alibi_slopes[:, None, None]
if s_aux is not None:
s_aux = s_aux.float()
s_aux = s_aux[:, None, None]
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx : start_idx + query_len].float()
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_kv_blocks]
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
k = k[:kv_len].float()
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
v = v[:kv_len].float()
if q.shape[1] != k.shape[1]:
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k).float()
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = (
torch.triu(
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
)
.bool()
.logical_not()
)
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
if alibi_slopes is not None:
q_start_pos = kv_len - query_len
q_pos = q_start_pos + torch.arange(0, query_len)[None, :, None]
kv_pos = torch.arange(0, kv_len)[None, None, :]
dist = q_pos - kv_pos
alibi_bias = -alibi_slopes * dist
attn += alibi_bias
attn.masked_fill_(mask, float("-inf"))
if s_aux is not None:
s_aux_ext = s_aux.repeat(1, query_len, 1)
attn = torch.cat((s_aux_ext, attn), dim=-1)
attn = torch.softmax(attn, dim=-1)
if s_aux is not None:
attn = attn[:, :, 1:]
out = torch.einsum("hqk,khd->qhd", attn, v).to(dtype=dtype)
outputs.append(out)
start_idx += query_len
return torch.cat(outputs, dim=0)
@torch.inference_mode()
def varlen_with_paged_kv(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
scale = head_size**-0.5
token_num = sum(query_lens)
# for n heads the set of slopes is the geometric sequence that starts
# 2^(-8/n)
alibi_slopes = _get_alibi_slopes(num_query_heads) if use_alibi else None
s_aux = (
15 * torch.rand((num_query_heads,), dtype=torch.bfloat16) if use_sink else None
)
query = tensor_cache(
elem_num=token_num * num_query_heads * head_size,
dtype=dtype,
)
query = query.view(
token_num,
num_query_heads,
head_size,
)
key_value = tensor_cache(
elem_num=2 * num_blocks * num_kv_heads * block_size * head_size,
dtype=dtype,
)
key_value = key_value.view(
2,
num_blocks,
block_size,
num_kv_heads,
head_size,
)
key_cache, value_cache = key_value.unbind(0)
# KV cache for CPU attention
packed_key_cache = torch.empty(
num_blocks, num_kv_heads, block_size, head_size, dtype=dtype
)
packed_value_cache = torch.empty_like(packed_key_cache)
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
# use reshape_and_cache to pack key_cache and value_cache
slot_mapping = torch.arange(0, num_blocks * block_size, dtype=torch.int64)
cpu_attn_reshape_and_cache(
key=key_cache.view(-1, num_kv_heads, head_size),
value=value_cache.view(-1, num_kv_heads, head_size),
key_cache=packed_key_cache,
value_cache=packed_value_cache,
slot_mapping=slot_mapping,
isa=isa,
)
metadata = cpu_attn_get_scheduler_metadata(
num_reqs=num_seqs,
num_heads=num_query_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
seq_lens=kv_lens_tensor,
dtype=dtype,
query_start_loc=cu_query_lens,
causal=True,
sliding_window_size=sliding_window if sliding_window is not None else -1,
isa=isa,
enable_kv_split=False,
)
out_without_split = torch.empty_like(query)
cpu_attention_with_kv_cache(
query=query,
key_cache=packed_key_cache,
value_cache=packed_value_cache,
output=out_without_split,
query_start_loc=cu_query_lens,
seq_lens=kv_lens_tensor,
scale=scale,
causal=True,
alibi_slopes=alibi_slopes,
sliding_window=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
scheduler_metadata=metadata,
s_aux=s_aux,
)
metadata = cpu_attn_get_scheduler_metadata(
num_reqs=num_seqs,
num_heads=num_query_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
seq_lens=kv_lens_tensor,
dtype=dtype,
query_start_loc=cu_query_lens,
causal=True,
sliding_window_size=sliding_window if sliding_window is not None else -1,
isa=isa,
enable_kv_split=True,
)
out_with_split = torch.empty_like(query)
cpu_attention_with_kv_cache(
query=query,
key_cache=packed_key_cache,
value_cache=packed_value_cache,
output=out_with_split,
query_start_loc=cu_query_lens,
seq_lens=kv_lens_tensor,
scale=scale,
causal=True,
alibi_slopes=alibi_slopes,
sliding_window=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
scheduler_metadata=metadata,
s_aux=s_aux,
)
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
alibi_slopes=alibi_slopes,
s_aux=s_aux,
)
atol, rtol = 1.5e-2, 1e-2
(
torch.testing.assert_close(out_with_split, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(out_with_split - ref_output))}",
)
(
torch.testing.assert_close(out_without_split, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(out_without_split - ref_output))}",
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [96, 128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", QTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["vec"])
def test_varlen_with_paged_kv_normal_vec(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
varlen_with_paged_kv(
seq_lens=seq_lens,
num_heads=num_heads,
head_size=head_size,
sliding_window=sliding_window,
dtype=dtype,
block_size=block_size,
soft_cap=soft_cap,
num_blocks=num_blocks,
use_alibi=use_alibi,
use_sink=use_sink,
isa=isa,
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [96, 128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["amx"])
@pytest.mark.skipif(
not torch._C._cpu._is_amx_tile_supported(), reason="no AMX support."
)
def test_varlen_with_paged_kv_normal_amx(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
varlen_with_paged_kv(
seq_lens=seq_lens,
num_heads=num_heads,
head_size=head_size,
sliding_window=sliding_window,
dtype=dtype,
block_size=block_size,
soft_cap=soft_cap,
num_blocks=num_blocks,
use_alibi=use_alibi,
use_sink=use_sink,
isa=isa,
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [48])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["vec16"])
def test_varlen_with_paged_kv_normal_vec16(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
varlen_with_paged_kv(
seq_lens=seq_lens,
num_heads=num_heads,
head_size=head_size,
sliding_window=sliding_window,
dtype=dtype,
block_size=block_size,
soft_cap=soft_cap,
num_blocks=num_blocks,
use_alibi=use_alibi,
use_sink=use_sink,
isa=isa,
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [96])
@pytest.mark.parametrize("block_size", [128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [50])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize(
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
def test_varlen_with_paged_kv_softcap(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
varlen_with_paged_kv(
seq_lens=seq_lens,
num_heads=num_heads,
head_size=head_size,
sliding_window=sliding_window,
dtype=dtype,
block_size=block_size,
soft_cap=soft_cap,
num_blocks=num_blocks,
use_alibi=use_alibi,
use_sink=use_sink,
isa=isa,
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [96])
@pytest.mark.parametrize("block_size", [128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [True])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize(
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
def test_varlen_with_paged_kv_alibi(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
varlen_with_paged_kv(
seq_lens=seq_lens,
num_heads=num_heads,
head_size=head_size,
sliding_window=sliding_window,
dtype=dtype,
block_size=block_size,
soft_cap=soft_cap,
num_blocks=num_blocks,
use_alibi=use_alibi,
use_sink=use_sink,
isa=isa,
)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [96])
@pytest.mark.parametrize("block_size", [128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [True])
@pytest.mark.parametrize(
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
def test_varlen_with_paged_kv_sink(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: int | None,
dtype: torch.dtype,
block_size: int,
soft_cap: float | None,
num_blocks: int,
use_alibi: bool,
use_sink: bool,
isa: str,
) -> None:
varlen_with_paged_kv(
seq_lens=seq_lens,
num_heads=num_heads,
head_size=head_size,
sliding_window=sliding_window,
dtype=dtype,
block_size=block_size,
soft_cap=soft_cap,
num_blocks=num_blocks,
use_alibi=use_alibi,
use_sink=use_sink,
isa=isa,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""
import pytest
import torch
......
......@@ -38,7 +38,11 @@ AITER_MODEL_LIST = [
[
pytest.param(
"bigscience/bloom-560m", # bloom - testing alibi slopes
marks=[pytest.mark.core_model, pytest.mark.slow_test],
marks=[
pytest.mark.core_model,
pytest.mark.slow_test,
pytest.mark.cpu_model,
],
),
pytest.param(
"openai-community/gpt2", # gpt2
......@@ -55,6 +59,10 @@ AITER_MODEL_LIST = [
pytest.mark.slow_test,
],
),
pytest.param(
"google/gemma-2-2b-it", # test hybrid attention
marks=[pytest.mark.cpu_model],
),
pytest.param(
"zai-org/chatglm3-6b", # chatglm (text-only)
),
......@@ -64,7 +72,6 @@ AITER_MODEL_LIST = [
),
pytest.param(
"openbmb/MiniCPM3-4B",
# fused_moe not supported on CPU
marks=[pytest.mark.core_model, large_gpu_mark(min_gb=32)],
),
pytest.param(
......@@ -93,11 +100,7 @@ AITER_MODEL_LIST = [
pytest.param("bigcode/starcoder2-3b"), # starcoder2
pytest.param(
"TitanML/tiny-mixtral", # mixtral
marks=[pytest.mark.core_model],
),
pytest.param(
"allenai/OLMoE-1B-7B-0924-Instruct",
marks=[pytest.mark.cpu_model],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
pytest.param("swiss-ai/Apertus-8B-Instruct-2509"), # apertus
],
......
......@@ -23,8 +23,7 @@ from ...utils import check_embeddings_close
),
pytest.param(
"intfloat/e5-mistral-7b-instruct",
# CPU v1 doesn't support sliding window
marks=[pytest.mark.core_model],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
pytest.param(
"ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model]
......
......@@ -243,7 +243,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
"FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
"Gemma2ForCausalLM": _HfExamplesInfo(
"google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
......
......@@ -2583,6 +2583,88 @@ def onednn_scaled_mm(
return output
def cpu_attn_get_scheduler_metadata(
num_reqs: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
seq_lens: torch.Tensor,
dtype: torch.dtype,
query_start_loc: torch.Tensor,
causal: bool,
sliding_window_size: int,
isa: str,
enable_kv_split: bool,
) -> torch.Tensor:
sheduler_metadata = torch.ops._C.get_scheduler_metadata(
num_reqs,
num_heads,
num_kv_heads,
head_dim,
seq_lens,
dtype,
query_start_loc,
causal,
sliding_window_size,
isa,
enable_kv_split,
)
return sheduler_metadata
def cpu_attn_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
isa: str,
) -> None:
torch.ops._C.cpu_attn_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping,
isa,
)
def cpu_attention_with_kv_cache(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
output: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens: torch.Tensor,
scale: float,
causal: bool,
alibi_slopes: torch.Tensor | None,
sliding_window: tuple[int, int],
block_table: torch.Tensor,
softcap: float,
scheduler_metadata: torch.Tensor,
s_aux: torch.Tensor | None,
) -> None:
torch.ops._C.cpu_attention_with_kv_cache(
query,
key_cache,
value_cache,
output,
query_start_loc,
seq_lens,
scale,
causal,
alibi_slopes,
sliding_window[0],
sliding_window[1],
block_table,
softcap,
scheduler_metadata,
s_aux,
)
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
......
......@@ -49,7 +49,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
)
TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
TORCH_SDPA = "" # this tag is only used for ViT
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
......@@ -70,6 +70,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM = ""
......
......@@ -1726,9 +1726,6 @@ class EngineArgs:
)
_raise_unsupported_error(feature_name=name)
if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
_raise_unsupported_error(feature_name="sliding window (CPU backend)")
def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
......
......@@ -8,7 +8,6 @@ import platform
import subprocess
import sys
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING
import regex as re
......@@ -139,16 +138,15 @@ class CpuPlatform(Platform):
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.")
logger.info("Using Torch SDPA backend.")
if not use_v1:
raise ValueError("CPU backend only supports V1.")
return AttentionBackendEnum.TORCH_SDPA.get_path()
return AttentionBackendEnum.CPU_ATTN.get_path()
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
......@@ -186,15 +184,13 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config
ipex_available = find_spec("intel_extension_for_pytorch") is not None
if cache_config.block_size is None:
cache_config.block_size = 128
if cache_config and cache_config.block_size is None:
cache_config.block_size = 128 if ipex_available else 16
if not ipex_available and cache_config.block_size != 16:
raise RuntimeError(
f"--block-size={cache_config.block_size} requires"
" intel_extension_for_pytorch"
if cache_config.block_size % 32 != 0:
logger.warning(
"CPU backend prefers block_size is multiples of 32, "
"otherwise the performance is not optimized."
)
scheduler_config = vllm_config.scheduler_config
......@@ -207,22 +203,11 @@ class CpuPlatform(Platform):
"backend is not compatible with FP8 KV cache."
)
if cache_config.cache_dtype == "fp8_e4m3":
cache_config.cache_dtype = "fp8_e5m2"
logger.warning(
"CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2."
)
if (
cache_config.cache_dtype != "auto"
and model_config is not None
and model_config.dtype == torch.half
):
if cache_config.cache_dtype != "auto":
logger.warning(
"FP8 KV cache on the CPU backend only does not"
" support fp16 for now, cast to bf16."
"CPU backend doesn't support KV cache quantization fallback to auto."
)
model_config.dtype = torch.bfloat16
cache_config.cache_dtype = "auto"
cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()
......
......@@ -57,7 +57,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar
import numpy as np
import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionType,
is_quantized_kv_cache,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
......@@ -24,44 +23,38 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.kv_cache_interface import AttentionSpec
try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
_use_ipex = True
# AttributeError is to handle a bug in ipex
# https://github.com/intel/intel-extension-for-pytorch/pull/813
except (ImportError, AttributeError):
_use_ipex = False
from vllm import _custom_ops as ops
logger = init_logger(__name__)
_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86,)
class TorchSDPABackend(AttentionBackend):
accept_output_buffer: bool = False
class CPUAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
torch.float32,
]
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
attn_impl = _get_paged_attn_impl()
return attn_impl.get_supported_head_sizes()
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "TORCH_SDPA"
return "CPU_ATTN"
@staticmethod
def get_impl_cls() -> type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl
def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return CPUAttentionBackendImpl
@staticmethod
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
return TorchSDPAMetadataBuilderV1
def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]:
return CPUAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
......@@ -71,9 +64,7 @@ class TorchSDPABackend(AttentionBackend):
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return _get_paged_attn_impl().get_kv_cache_shape(
num_blocks, block_size, num_kv_heads, head_size
)
return 2, num_blocks, num_kv_heads, block_size, head_size
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
......@@ -81,264 +72,26 @@ class TorchSDPABackend(AttentionBackend):
@dataclass
class TorchSDPAMetadata(AttentionMetadata):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
class CPUAttentionMetadata:
isa: str
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
decode_seq_lens_tensor: torch.Tensor | None
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
decode_max_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
decode_block_tables: torch.Tensor | None
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
chunked_prefill: bool
seq_lens: list[int] | None = None # For non-chunked prefill
# For chunked prefill only
max_query_len: int | None = None
prefill_max_seq_len: int | None = None
prefill_query_start_loc: torch.Tensor | None = None
prefill_seq_start_loc: torch.Tensor | None = None
prefill_block_tables: torch.Tensor | None = None
# For V1 logits index only
query_start_loc: torch.Tensor | None = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: list[int] | None = None
encoder_seq_lens_tensor: torch.Tensor | None = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: int | None = None
# Number of tokens input to encoder
num_encoder_tokens: int | None = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: torch.Tensor | None = None
cross_block_tables: torch.Tensor | None = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: list[torch.Tensor] | None = None
self.encoder_attn_bias: list[torch.Tensor] | None = None
self.cross_attn_bias: list[torch.Tensor] | None = None
@property
def is_all_encoder_attn_metadata_set(self):
"""
All attention metadata required for encoder attention is set.
"""
return (
(self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None)
)
@property
def is_all_cross_attn_metadata_set(self):
"""
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
"""
return (
self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None)
)
@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
if self.num_prefill_tokens == 0:
return None
return self
@property
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
if self.num_decode_tokens == 0:
return None
return self
def get_seq_lens(
self,
attn_type: str,
):
"""
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
"""
if (
attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY
):
seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER:
seq_lens_q = self.encoder_seq_lens
seq_lens_kv = self.encoder_seq_lens
elif attn_type == AttentionType.ENCODER_DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.encoder_seq_lens
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
return seq_lens_q, seq_lens_kv
def get_attn_bias(
self,
attn_type: str,
) -> list[torch.Tensor] | None:
"""
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
"""
if (
attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY
):
return self.attn_bias
elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
return self.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def set_attn_bias(
self,
attn_bias: list[torch.Tensor],
attn_type: str,
) -> None:
"""
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
"""
if (
attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY
):
self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
self.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_seq_len_block_table_args(
self,
attn_type: str,
) -> tuple:
"""
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
"""
if (
attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY
):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return (
self.decode_seq_lens_tensor,
self.decode_max_seq_len,
self.decode_block_tables,
)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (
self.encoder_seq_lens_tensor,
self.max_encoder_seq_len,
self.cross_block_tables,
)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
scheduler_metadata: torch.Tensor | None
causal: bool = True
# can be removed after deprecate sdpa
use_sdpa_prefill: bool = False
num_decode_tokens: int = 0
sdpa_attn_masks: list[torch.Tensor | None] | None = None
sdpa_start_loc: torch.Tensor | None = None
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
reorder_batch_threshold: int = 1
class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]):
def __init__(
self,
kv_cache_spec: AttentionSpec,
......@@ -348,80 +101,104 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
) -> None:
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.scheduler_config = vllm_config.scheduler_config
self._init_reorder_batch_threshold(1, False)
self.use_sdpa_prefill = False
reorder_batch_threshold = None
if current_platform.get_cpu_architecture() not in _CPU_ARCH_PREFER_MIXED_BATCH:
# in this case, decode seqs are reordered to the front of prefill seqs
# to split decode and prefill. Then use SDPA for prefill and
# cpu_attention_with_kv_cache for decode
reorder_batch_threshold = 1
self.use_sdpa_prefill = True
self.seq_start_loc_cpu = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device="cpu",
self._init_reorder_batch_threshold(reorder_batch_threshold, False)
self.kv_cache_spec = kv_cache_spec
self.vllm_config = vllm_config
parallel_config = vllm_config.parallel_config
self.num_kv_heads = vllm_config.model_config.get_num_kv_heads(parallel_config)
self.num_heads = vllm_config.model_config.get_num_attention_heads(
parallel_config
)
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
self.head_dim = kv_cache_spec.head_size
self.dtype = vllm_config.model_config.dtype
self.window_size = getattr(kv_cache_spec, "sliding_window", -1)
if self.window_size is None:
self.window_size = -1
self.block_size = vllm_config.cache_config.block_size
self.isa = _get_attn_isa(self.dtype, self.block_size)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> TorchSDPAMetadata:
) -> CPUAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens_np = seq_lens_cpu.numpy()
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
query_start_loc_np = query_start_loc_cpu.numpy()
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=True,
)
)
max_prefill_seq_len = (
seq_lens_np[num_decodes:num_reqs].max().item() if num_prefills > 0 else 0
)
max_decode_seq_len = (
seq_lens_np[:num_decodes].max().item() if num_prefills < num_reqs else 0
)
self.seq_start_loc_np[0] = 0
np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1 : num_reqs + 1])
slot_mapping = common_attn_metadata.slot_mapping.long()
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
query_start_loc_np = query_start_loc_cpu.numpy()
query_start_loc_np[num_decodes : num_reqs + 1] -= num_decode_tokens
slot_mapping = common_attn_metadata.slot_mapping
causal = common_attn_metadata.causal
sdpa_start_loc = query_start_loc
num_decode_tokens = 0
if self.use_sdpa_prefill and causal:
# Decoder, need reorder and truncate
assert self.reorder_batch_threshold
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=True,
)
)
num_reqs = num_decodes
sdpa_start_loc = sdpa_start_loc[num_decodes:] - num_decode_tokens
seq_lens = seq_lens[:num_decodes]
query_start_loc = query_start_loc[: num_decodes + 1]
block_table_tensor = block_table_tensor[:num_decodes]
sheduler_metadata = None
if causal:
# for decode batch, use the custom kernel
sheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
num_reqs=num_reqs,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
seq_lens=seq_lens,
dtype=self.dtype,
query_start_loc=query_start_loc,
causal=causal,
sliding_window_size=self.window_size,
isa=self.isa,
enable_kv_split=True,
)
attn_metadata = TorchSDPAMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
# to ensure inference when chunked_prefill is disabled
seq_lens=seq_lens_cpu.tolist()[num_decodes:], # prefill
decode_seq_lens_tensor=seq_lens_cpu[:num_decodes], # decode
decode_max_seq_len=max_decode_seq_len, # decode
decode_block_tables=block_table_tensor[:num_decodes], # decode
chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
attn_metadata = CPUAttentionMetadata(
isa=self.isa,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
prefill_max_seq_len=max_prefill_seq_len,
prefill_query_start_loc=query_start_loc_cpu[
num_decodes : num_reqs + 1
], # prefill
prefill_seq_start_loc=self.seq_start_loc_cpu[
num_decodes : num_reqs + 1
], # prefill
prefill_block_tables=block_table_tensor[num_decodes:num_reqs], # prefill
query_start_loc=query_start_loc_cpu[: num_reqs + 1], # for logits index
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
scheduler_metadata=sheduler_metadata,
causal=causal,
use_sdpa_prefill=self.use_sdpa_prefill,
num_decode_tokens=num_decode_tokens,
sdpa_start_loc=sdpa_start_loc,
)
return attn_metadata
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
class CPUAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
......@@ -434,37 +211,48 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
logits_soft_cap: float | None = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
sinks: torch.Tensor | None = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
if logits_soft_cap is not None:
logger.warning_once(
"Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off."
)
self.paged_attn_impl = _get_paged_attn_impl()
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
if logits_soft_cap is not None and attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
logger.warning_once(
"CPU_ATTN does not support logits softcap for"
" ENCODER and ENCODER_ONLY, outputs may be slightly off"
)
if logits_soft_cap is None:
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
if sliding_window is None:
self.sliding_window = (-1, -1)
elif attn_type == AttentionType.ENCODER_ONLY:
self.sliding_window = (sliding_window - 1, sliding_window - 1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (
self.alibi_slopes is not None or self.sliding_window is not None
)
if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex:
raise NotImplementedError(
"Torch SDPA backend FP8 KV cache requires "
"intel_extension_for_pytorch support."
)
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError("FP8 KV cache is unsupported in CPU_ATTN")
self.attn_type = attn_type
self.sinks = sinks
if self.sinks is not None:
assert self.sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
"heads in the layer"
)
def forward(
self,
layer: AttentionLayer,
......@@ -472,196 +260,130 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore
attn_metadata: CPUAttentionMetadata | None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
"""Forward pass for CPU attention backend.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
[2, num_blocks, num_kv_heads, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl"
" for CPUAttentionBackendImpl"
)
# For warming-up
if attn_metadata is None:
return query
attn_type = self.attn_type
if attn_type == AttentionType.ENCODER and (
not attn_metadata.is_all_encoder_attn_metadata_set
):
raise AttributeError(
"Encoder attention requires setting encoder metadata attributes."
)
elif attn_type == AttentionType.ENCODER_DECODER and (
not attn_metadata.is_all_cross_attn_metadata_set
):
raise AttributeError(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
return output
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
return self._run_sdpa_forward(
query[:num_actual_tokens],
key[:num_actual_tokens],
value[:num_actual_tokens],
output[:num_actual_tokens],
attn_metadata,
self.attn_type,
)
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = self.paged_attn_impl.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
# For decoder and cross-attention, use KV cache, size are
# [num_blocks, num_kv_heads, block_size, head_size]
key_cache, value_cache = kv_cache.unbind(0)
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
self.paged_attn_impl.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
ops.cpu_attn_reshape_and_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
attn_metadata.isa,
)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
if attn_metadata.use_sdpa_prefill:
assert self.sinks is None, "Attention sink is unsupported in SDPA prefill"
num_decode_tokens = attn_metadata.num_decode_tokens
else:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
num_decode_tokens = 0
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore
assert attn_metadata.seq_lens is not None
self._run_sdpa_forward(
output, query, key, value, prefill_meta, attn_type=attn_type
)
else:
# prefix-enabled attention
assert not self.need_mask
import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[prefill_meta.num_decode_tokens :, :, :],
query[prefill_meta.num_decode_tokens :, :, :],
key_cache,
value_cache,
prefill_meta.prefill_query_start_loc,
prefill_meta.prefill_seq_start_loc,
prefill_meta.max_query_len,
prefill_meta.prefill_max_seq_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata."
self._run_sdpa_forward(
query[num_decode_tokens:num_actual_tokens],
key[num_decode_tokens:num_actual_tokens],
value[num_decode_tokens:num_actual_tokens],
output[num_decode_tokens:num_actual_tokens],
attn_metadata,
self.attn_type,
)
# Decoding run.
(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = decode_meta.get_seq_len_block_table_args(attn_type)
self.paged_attn_impl.forward_decode(
output[: attn_metadata.num_decode_tokens, :, :],
query[: attn_metadata.num_decode_tokens, :, :],
key_cache,
value_cache,
block_tables_arg,
seq_lens_arg,
max_seq_len_arg,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
num_actual_tokens = num_decode_tokens
if num_actual_tokens > 0:
ops.cpu_attention_with_kv_cache(
query=query[:num_actual_tokens],
key_cache=key_cache,
value_cache=value_cache,
output=output[:num_actual_tokens], # type: ignore
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes, # type: ignore
sliding_window=self.sliding_window,
block_table=attn_metadata.block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=attn_metadata.scheduler_metadata,
s_aux=self.sinks,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
return output
def _run_sdpa_forward(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: TorchSDPAMetadata,
attn_type: str = AttentionType.DECODER,
) -> None:
attn_masks = attn_metadata.get_attn_bias(attn_type)
output: torch.Tensor,
attn_metadata: CPUAttentionMetadata,
attn_type: str,
) -> torch.Tensor:
attn_masks = attn_metadata.sdpa_attn_masks
if attn_masks is None:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens, # type: ignore
attn_metadata.sdpa_start_loc,
)
elif self.sliding_window is not None:
elif self.sliding_window[0] != -1 or self.sliding_window[1] != -1:
assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window, query.dtype
attn_metadata.sdpa_start_loc,
self.sliding_window[0],
self.sliding_window[1],
query.dtype,
)
else:
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)
attn_masks = [None] * (attn_metadata.sdpa_start_loc.size(0) - 1) # type: ignore
attn_metadata.sdpa_attn_masks = attn_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
......@@ -673,21 +395,16 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
causal_attn = attn_type == AttentionType.DECODER
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)
# Incoming Q and KV contain decoded tokens as well, hence start at an offset
# equal to num_decode_tokens since decode requests appear first
start_q, start_kv = (
attn_metadata.num_decode_tokens,
attn_metadata.num_decode_tokens,
)
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks):
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy() # type: ignore
for i in range(len(attn_masks)):
mask = attn_masks[i]
start_q = sdpa_start_loc[i]
end_q = sdpa_start_loc[i + 1]
sub_out = (
scaled_dot_product_attention(
torch.nn.functional.scaled_dot_product_attention(
query[None, :, start_q:end_q, :],
key[None, :, start_kv:end_kv, :],
value[None, :, start_kv:end_kv, :],
key[None, :, start_q:end_q, :],
value[None, :, start_q:end_q, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=causal_attn and mask is None,
......@@ -697,17 +414,20 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
.movedim(query.dim() - 2, 0)
)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
return output
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: list[int],
sdpa_start_loc: torch.Tensor,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
seq_num = sdpa_start_loc.size(0) - 1
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
for i in range(seq_num):
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
bias = torch.arange(seq_len, dtype=dtype) # type: ignore
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
......@@ -719,7 +439,7 @@ def _make_alibi_bias(
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = (
torch.empty((1, seq_len, seq_len), dtype=bias.dtype)
torch.empty((1, seq_len, seq_len), dtype=bias.dtype) # type: ignore
.fill_(-torch.inf)
.triu_(diagonal=1)
)
......@@ -729,210 +449,37 @@ def _make_alibi_bias(
def _make_sliding_window_bias(
seq_lens: list[int],
window_size: int | None,
sdpa_start_loc: torch.Tensor,
left_window_size: int,
right_window_size: int,
dtype: torch.dtype,
) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens:
tensor = torch.full(
(1, seq_len, seq_len),
dtype=dtype,
seq_num = sdpa_start_loc.size(0) - 1
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
for i in range(seq_num):
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
mask = torch.full( # type: ignore
(1, seq_len, seq_len), # type: ignore
fill_value=1,
dtype=dtype,
)
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
if window_size is not None:
mask = torch.triu(mask, diagonal=shift - window_size + 1)
if right_window_size != -1:
mask = torch.tril(mask, diagonal=right_window_size)
if left_window_size != -1:
mask = torch.triu(mask, diagonal=-left_window_size)
mask = torch.log(mask)
attn_biases.append(mask.to(dtype))
attn_biases.append(mask)
return attn_biases
class _PagedAttention:
@staticmethod
def get_supported_head_sizes() -> list[int]:
return [32, 64, 80, 96, 112, 128, 192, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
*args,
) -> tuple[int, ...]:
return 2, num_blocks, block_size * num_kv_heads * head_size
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: torch.Tensor | None,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
tp_rank: int = 0
blocksparse_local_blocks: int = 0
blocksparse_vert_stride: int = 0
blocksparse_block_size: int = 64
blocksparse_head_sliding_step: int = 0
block_size = value_cache.shape[3]
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
class _IPEXPagedAttention(_PagedAttention):
@staticmethod
def get_supported_head_sizes() -> list[int]:
return []
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> tuple[torch.Tensor, torch.Tensor]:
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
ipex_modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slot_mapping.flatten().int()
)
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: torch.Tensor | None,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
*args,
) -> None:
block_size = value_cache.shape[2]
head_mapping = (
torch.arange(
0,
num_kv_heads,
device="cpu",
dtype=torch.int32,
)
.view(num_kv_heads, 1)
.repeat_interleave(query.size(1) // num_kv_heads)
.flatten()
)
ipex_modules.PagedAttention.single_query_cached_kv_attention(
output,
query.contiguous(),
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
def _get_paged_attn_impl():
if _use_ipex:
return _IPEXPagedAttention
def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
supports_amx = torch._C._cpu._is_amx_tile_supported()
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return "amx"
elif block_size % 32 == 0:
return "vec"
else:
return _PagedAttention
return "vec16"
......@@ -265,7 +265,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int = 1,
reorder_batch_threshold: int | None = 1,
supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False,
) -> None:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import Any
import torch
import torch.nn as nn
......@@ -12,9 +12,6 @@ from vllm.model_executor.model_loader import get_model
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__)
......@@ -31,15 +28,6 @@ class CPUModelRunner(GPUModelRunner):
self._postprocess_tensors()
# Note: Remove the override after new attention backend finished
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
if len(self.kv_cache_config.kv_cache_groups) > 1:
raise ValueError(
"Multiple KVCacheGroups is not"
"currently supported with CPU model runner."
)
super()._may_reorder_batch(scheduler_output)
def _postprocess_tensors(self) -> None:
# Note: replace device tensors with cpu tensors
def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment