Unverified Commit 7f829be7 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[CPU] Refactor CPU attention backend (#27954)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent e1710393
...@@ -35,7 +35,7 @@ DEVICE_MLA_BACKENDS = { ...@@ -35,7 +35,7 @@ DEVICE_MLA_BACKENDS = {
DEVICE_REGULAR_ATTN_BACKENDS = { DEVICE_REGULAR_ATTN_BACKENDS = {
"cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"], "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
"hip": ["ROCM_ATTN"], "hip": ["ROCM_ATTN"],
"cpu": ["TORCH_SDPA"], "cpu": ["CPU_ATTN"],
} }
DEVICE_MLA_BLOCK_SIZES = { DEVICE_MLA_BLOCK_SIZES = {
...@@ -86,7 +86,7 @@ def test_env( ...@@ -86,7 +86,7 @@ def test_env(
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size) backend = get_attn_backend(16, torch.float16, None, block_size)
assert backend.get_name() == "TORCH_SDPA" assert backend.get_name() == "CPU_ATTN"
elif device == "hip": elif device == "hip":
with patch("vllm.platforms.current_platform", RocmPlatform()): with patch("vllm.platforms.current_platform", RocmPlatform()):
...@@ -224,7 +224,7 @@ def test_fp32_fallback(device: str): ...@@ -224,7 +224,7 @@ def test_fp32_fallback(device: str):
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16) backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "TORCH_SDPA" assert backend.get_name() == "CPU_ATTN"
elif device == "cuda": elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()): with patch("vllm.platforms.current_platform", CudaPlatform()):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import math
import pytest
import torch
from vllm.platforms import current_platform
if not current_platform.is_cpu():
pytest.skip("skipping CPU-only tests", allow_module_level=True)
from vllm._custom_ops import (
cpu_attention_with_kv_cache,
cpu_attn_get_scheduler_metadata,
cpu_attn_reshape_and_cache,
)
NUM_HEADS = [
(4, 4),
(8, 2),
(9, 3),
]
HEAD_SIZES = [96, 128]
QTYPES = [torch.bfloat16, torch.half, torch.float32]
SLIDING_WINDOWS = [None, 256]
NUM_BLOCKS = [
1024,
]
SEQ_LENS = [ # (q_len, kv_len)
[(1, 213), (1, 1), (1, 312), (1, 7), (1, 7812)], # decode batch
[(2345, 2345), (5, 5), (3, 16), (134, 5131)], # prefill batch
[(992, 2456), (1, 1234), (98, 1145), (1, 4162), (2345, 2345)], # mixed batch
]
# Random number generation is slow, so random tensors are cached and reused.
@functools.lru_cache(maxsize=128)
def tensor_cache(
    elem_num: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Return a flat random-normal tensor, memoized per ``(elem_num, dtype)``.

    Repeated calls with the same arguments hand back the very same tensor
    object, so callers share storage with the cache entry.
    """
    return torch.randn(elem_num, dtype=dtype)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(
closest_power_of_2, total_num_heads - closest_power_of_2
)
extra_powers = torch.arange(
start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes.float()
def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: list[int],
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: int | None = None,
soft_cap: float | None = None,
alibi_slopes: torch.Tensor | None = None,
s_aux: torch.Tensor | None = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
dtype = query.dtype
outputs: list[torch.Tensor] = []
start_idx = 0
if alibi_slopes is not None:
alibi_slopes = alibi_slopes[:, None, None]
if s_aux is not None:
s_aux = s_aux.float()
s_aux = s_aux[:, None, None]
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx : start_idx + query_len].float()
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_kv_blocks]
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
k = k[:kv_len].float()
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
v = v[:kv_len].float()
if q.shape[1] != k.shape[1]:
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k).float()
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = (
torch.triu(
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
)
.bool()
.logical_not()
)
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
if alibi_slopes is not None:
q_start_pos = kv_len - query_len
q_pos = q_start_pos + torch.arange(0, query_len)[None, :, None]
kv_pos = torch.arange(0, kv_len)[None, None, :]
dist = q_pos - kv_pos
alibi_bias = -alibi_slopes * dist
attn += alibi_bias
attn.masked_fill_(mask, float("-inf"))
if s_aux is not None:
s_aux_ext = s_aux.repeat(1, query_len, 1)
attn = torch.cat((s_aux_ext, attn), dim=-1)
attn = torch.softmax(attn, dim=-1)
if s_aux is not None:
attn = attn[:, :, 1:]
out = torch.einsum("hqk,khd->qhd", attn, v).to(dtype=dtype)
outputs.append(out)
start_idx += query_len
return torch.cat(outputs, dim=0)
@torch.inference_mode()
def varlen_with_paged_kv(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    use_alibi: bool,
    use_sink: bool,
    isa: str,
) -> None:
    """Compare the CPU paged-attention kernel against ``ref_paged_attn``.

    The kernel is run twice, with scheduler metadata built with
    ``enable_kv_split=False`` and then ``True``, and both outputs must match
    the reference within ``atol=1.5e-2`` / ``rtol=1e-2``.

    Args:
        seq_lens: ``(query_len, kv_len)`` pair of every sequence in the batch.
        num_heads: ``(num_query_heads, num_kv_heads)``.
        head_size: Size of each attention head.
        sliding_window: Sliding-window length, or ``None`` to disable it.
        dtype: Dtype of the query and the KV cache.
        block_size: Number of tokens per KV-cache block.
        soft_cap: Logit soft cap, or ``None`` to disable it (passed as ``0``).
        num_blocks: Number of KV-cache blocks to allocate.
        use_alibi: Whether to pass ALiBi slopes.
        use_sink: Whether to pass random per-head ``s_aux`` values.
        isa: ISA name forwarded to the CPU attention ops.

    Raises:
        AssertionError: If either kernel output differs from the reference.
    """
    current_platform.seed_everything(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
    sliding_window_size = sliding_window if sliding_window is not None else -1
    softcap = soft_cap if soft_cap is not None else 0
    scale = head_size**-0.5
    token_num = sum(query_lens)

    # for n heads the set of slopes is the geometric sequence that starts
    # 2^(-8/n)
    alibi_slopes = _get_alibi_slopes(num_query_heads) if use_alibi else None

    s_aux = (
        15 * torch.rand((num_query_heads,), dtype=torch.bfloat16) if use_sink else None
    )

    query = tensor_cache(
        elem_num=token_num * num_query_heads * head_size,
        dtype=dtype,
    )
    query = query.view(
        token_num,
        num_query_heads,
        head_size,
    )

    # Reference-layout KV cache: [2, num_blocks, block_size, kv_heads, dim].
    key_value = tensor_cache(
        elem_num=2 * num_blocks * num_kv_heads * block_size * head_size,
        dtype=dtype,
    )
    key_value = key_value.view(
        2,
        num_blocks,
        block_size,
        num_kv_heads,
        head_size,
    )
    key_cache, value_cache = key_value.unbind(0)

    # KV cache for CPU attention
    packed_key_cache = torch.empty(
        num_blocks, num_kv_heads, block_size, head_size, dtype=dtype
    )
    packed_value_cache = torch.empty_like(packed_key_cache)

    cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
        dim=0, dtype=torch.int32
    )
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    # use reshape_and_cache to pack key_cache and value_cache; the identity
    # slot mapping covers every slot of every block.
    slot_mapping = torch.arange(0, num_blocks * block_size, dtype=torch.int64)
    cpu_attn_reshape_and_cache(
        key=key_cache.view(-1, num_kv_heads, head_size),
        value=value_cache.view(-1, num_kv_heads, head_size),
        key_cache=packed_key_cache,
        value_cache=packed_value_cache,
        slot_mapping=slot_mapping,
        isa=isa,
    )

    # Run the kernel without and then with KV splitting. Both runs happen
    # before the reference is computed, as in the original ordering.
    kernel_outputs: dict[bool, torch.Tensor] = {}
    for enable_kv_split in (False, True):
        metadata = cpu_attn_get_scheduler_metadata(
            num_reqs=num_seqs,
            num_heads=num_query_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_size,
            seq_lens=kv_lens_tensor,
            dtype=dtype,
            query_start_loc=cu_query_lens,
            causal=True,
            sliding_window_size=sliding_window_size,
            isa=isa,
            enable_kv_split=enable_kv_split,
        )
        kernel_out = torch.empty_like(query)
        cpu_attention_with_kv_cache(
            query=query,
            key_cache=packed_key_cache,
            value_cache=packed_value_cache,
            output=kernel_out,
            query_start_loc=cu_query_lens,
            seq_lens=kv_lens_tensor,
            scale=scale,
            causal=True,
            alibi_slopes=alibi_slopes,
            sliding_window=window_size,
            block_table=block_tables,
            softcap=softcap,
            scheduler_metadata=metadata,
            s_aux=s_aux,
        )
        kernel_outputs[enable_kv_split] = kernel_out

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
        alibi_slopes=alibi_slopes,
        s_aux=s_aux,
    )

    atol, rtol = 1.5e-2, 1e-2

    def _assert_matches_ref(actual: torch.Tensor, label: str) -> None:
        # Attach the max abs diff through ``msg`` so it actually shows up in
        # the failure report instead of being a discarded expression.
        torch.testing.assert_close(
            actual,
            ref_output,
            atol=atol,
            rtol=rtol,
            msg=lambda default_msg: (
                f"{default_msg}\n"
                f"[{label}] max abs diff: "
                f"{torch.max(torch.abs(actual - ref_output))}"
            ),
        )

    _assert_matches_ref(kernel_outputs[True], "enable_kv_split=True")
    _assert_matches_ref(kernel_outputs[False], "enable_kv_split=False")
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [96, 128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", QTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["vec"])
def test_varlen_with_paged_kv_normal_vec(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    use_alibi: bool,
    use_sink: bool,
    isa: str,
) -> None:
    """Plain attention on the "vec" ISA for every dtype in ``QTYPES``.

    Soft cap, ALiBi and sink are all disabled in this parametrization.
    """
    case = dict(
        seq_lens=seq_lens,
        num_heads=num_heads,
        head_size=head_size,
        sliding_window=sliding_window,
        dtype=dtype,
        block_size=block_size,
        soft_cap=soft_cap,
        num_blocks=num_blocks,
        use_alibi=use_alibi,
        use_sink=use_sink,
        isa=isa,
    )
    varlen_with_paged_kv(**case)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [96, 128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["amx"])
@pytest.mark.skipif(
    not torch._C._cpu._is_amx_tile_supported(), reason="no AMX support."
)
def test_varlen_with_paged_kv_normal_amx(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    use_alibi: bool,
    use_sink: bool,
    isa: str,
) -> None:
    """Plain bfloat16 attention on the "amx" ISA.

    Skipped when the AMX tile check reports no support. Soft cap, ALiBi and
    sink are all disabled in this parametrization.
    """
    case = dict(
        seq_lens=seq_lens,
        num_heads=num_heads,
        head_size=head_size,
        sliding_window=sliding_window,
        dtype=dtype,
        block_size=block_size,
        soft_cap=soft_cap,
        num_blocks=num_blocks,
        use_alibi=use_alibi,
        use_sink=use_sink,
        isa=isa,
    )
    varlen_with_paged_kv(**case)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", [48])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize("isa", ["vec16"])
def test_varlen_with_paged_kv_normal_vec16(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    use_alibi: bool,
    use_sink: bool,
    isa: str,
) -> None:
    """Plain bfloat16 attention on the "vec16" ISA with a block size of 48.

    Soft cap, ALiBi and sink are all disabled in this parametrization.
    """
    case = dict(
        seq_lens=seq_lens,
        num_heads=num_heads,
        head_size=head_size,
        sliding_window=sliding_window,
        dtype=dtype,
        block_size=block_size,
        soft_cap=soft_cap,
        num_blocks=num_blocks,
        use_alibi=use_alibi,
        use_sink=use_sink,
        isa=isa,
    )
    varlen_with_paged_kv(**case)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [96])
@pytest.mark.parametrize("block_size", [128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [50])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize(
    "isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
def test_varlen_with_paged_kv_softcap(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    use_alibi: bool,
    use_sink: bool,
    isa: str,
) -> None:
    """bfloat16 attention with a logit soft cap of 50.

    Runs on "amx" when the AMX tile check reports support, otherwise on
    "vec". ALiBi and sink are disabled.
    """
    case = dict(
        seq_lens=seq_lens,
        num_heads=num_heads,
        head_size=head_size,
        sliding_window=sliding_window,
        dtype=dtype,
        block_size=block_size,
        soft_cap=soft_cap,
        num_blocks=num_blocks,
        use_alibi=use_alibi,
        use_sink=use_sink,
        isa=isa,
    )
    varlen_with_paged_kv(**case)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [96])
@pytest.mark.parametrize("block_size", [128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [True])
@pytest.mark.parametrize("use_sink", [False])
@pytest.mark.parametrize(
    "isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
def test_varlen_with_paged_kv_alibi(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    use_alibi: bool,
    use_sink: bool,
    isa: str,
) -> None:
    """bfloat16 attention with ALiBi slopes enabled.

    Runs on "amx" when the AMX tile check reports support, otherwise on
    "vec". Soft cap and sink are disabled.
    """
    case = dict(
        seq_lens=seq_lens,
        num_heads=num_heads,
        head_size=head_size,
        sliding_window=sliding_window,
        dtype=dtype,
        block_size=block_size,
        soft_cap=soft_cap,
        num_blocks=num_blocks,
        use_alibi=use_alibi,
        use_sink=use_sink,
        isa=isa,
    )
    varlen_with_paged_kv(**case)
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [96])
@pytest.mark.parametrize("block_size", [128])
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("use_alibi", [False])
@pytest.mark.parametrize("use_sink", [True])
@pytest.mark.parametrize(
    "isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
)
def test_varlen_with_paged_kv_sink(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: int | None,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: float | None,
    num_blocks: int,
    use_alibi: bool,
    use_sink: bool,
    isa: str,
) -> None:
    """bfloat16 attention with per-head sink values (``s_aux``) enabled.

    Runs on "amx" when the AMX tile check reports support, otherwise on
    "vec". Soft cap and ALiBi are disabled.
    """
    case = dict(
        seq_lens=seq_lens,
        num_heads=num_heads,
        head_size=head_size,
        sliding_window=sliding_window,
        dtype=dtype,
        block_size=block_size,
        soft_cap=soft_cap,
        num_blocks=num_blocks,
        use_alibi=use_alibi,
        use_sink=use_sink,
        isa=isa,
    )
    varlen_with_paged_kv(**case)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""
import pytest import pytest
import torch import torch
......
...@@ -38,7 +38,11 @@ AITER_MODEL_LIST = [ ...@@ -38,7 +38,11 @@ AITER_MODEL_LIST = [
[ [
pytest.param( pytest.param(
"bigscience/bloom-560m", # bloom - testing alibi slopes "bigscience/bloom-560m", # bloom - testing alibi slopes
marks=[pytest.mark.core_model, pytest.mark.slow_test], marks=[
pytest.mark.core_model,
pytest.mark.slow_test,
pytest.mark.cpu_model,
],
), ),
pytest.param( pytest.param(
"openai-community/gpt2", # gpt2 "openai-community/gpt2", # gpt2
...@@ -55,6 +59,10 @@ AITER_MODEL_LIST = [ ...@@ -55,6 +59,10 @@ AITER_MODEL_LIST = [
pytest.mark.slow_test, pytest.mark.slow_test,
], ],
), ),
pytest.param(
"google/gemma-2-2b-it", # test hybrid attention
marks=[pytest.mark.cpu_model],
),
pytest.param( pytest.param(
"zai-org/chatglm3-6b", # chatglm (text-only) "zai-org/chatglm3-6b", # chatglm (text-only)
), ),
...@@ -64,7 +72,6 @@ AITER_MODEL_LIST = [ ...@@ -64,7 +72,6 @@ AITER_MODEL_LIST = [
), ),
pytest.param( pytest.param(
"openbmb/MiniCPM3-4B", "openbmb/MiniCPM3-4B",
# fused_moe not supported on CPU
marks=[pytest.mark.core_model, large_gpu_mark(min_gb=32)], marks=[pytest.mark.core_model, large_gpu_mark(min_gb=32)],
), ),
pytest.param( pytest.param(
...@@ -93,11 +100,7 @@ AITER_MODEL_LIST = [ ...@@ -93,11 +100,7 @@ AITER_MODEL_LIST = [
pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param("bigcode/starcoder2-3b"), # starcoder2
pytest.param( pytest.param(
"TitanML/tiny-mixtral", # mixtral "TitanML/tiny-mixtral", # mixtral
marks=[pytest.mark.core_model], marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
pytest.param(
"allenai/OLMoE-1B-7B-0924-Instruct",
marks=[pytest.mark.cpu_model],
), ),
pytest.param("swiss-ai/Apertus-8B-Instruct-2509"), # apertus pytest.param("swiss-ai/Apertus-8B-Instruct-2509"), # apertus
], ],
......
...@@ -23,8 +23,7 @@ from ...utils import check_embeddings_close ...@@ -23,8 +23,7 @@ from ...utils import check_embeddings_close
), ),
pytest.param( pytest.param(
"intfloat/e5-mistral-7b-instruct", "intfloat/e5-mistral-7b-instruct",
# CPU v1 doesn't support sliding window marks=[pytest.mark.core_model, pytest.mark.cpu_model],
marks=[pytest.mark.core_model],
), ),
pytest.param( pytest.param(
"ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model] "ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model]
......
...@@ -243,7 +243,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -243,7 +243,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
"FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"), "FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma2ForCausalLM": _HfExamplesInfo(
"google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"), "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"), "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
......
...@@ -2583,6 +2583,88 @@ def onednn_scaled_mm( ...@@ -2583,6 +2583,88 @@ def onednn_scaled_mm(
return output return output
def cpu_attn_get_scheduler_metadata(
    num_reqs: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    seq_lens: torch.Tensor,
    dtype: torch.dtype,
    query_start_loc: torch.Tensor,
    causal: bool,
    sliding_window_size: int,
    isa: str,
    enable_kv_split: bool,
) -> torch.Tensor:
    """Call ``torch.ops._C.get_scheduler_metadata`` and return its result.

    All arguments are forwarded positionally, unchanged and in declaration
    order.
    """
    return torch.ops._C.get_scheduler_metadata(
        num_reqs,
        num_heads,
        num_kv_heads,
        head_dim,
        seq_lens,
        dtype,
        query_start_loc,
        causal,
        sliding_window_size,
        isa,
        enable_kv_split,
    )
def cpu_attn_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    isa: str,
) -> None:
    """Call ``torch.ops._C.cpu_attn_reshape_and_cache``; returns nothing.

    All arguments are forwarded positionally, unchanged and in declaration
    order.
    """
    op_args = (key, value, key_cache, value_cache, slot_mapping, isa)
    torch.ops._C.cpu_attn_reshape_and_cache(*op_args)
def cpu_attention_with_kv_cache(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    output: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    causal: bool,
    alibi_slopes: torch.Tensor | None,
    sliding_window: tuple[int, int],
    block_table: torch.Tensor,
    softcap: float,
    scheduler_metadata: torch.Tensor,
    s_aux: torch.Tensor | None,
) -> None:
    """Call ``torch.ops._C.cpu_attention_with_kv_cache``; returns nothing.

    Arguments are forwarded positionally in declaration order, except that
    ``sliding_window`` is passed as two separate values (its elements 0
    and 1).
    """
    window_first = sliding_window[0]
    window_second = sliding_window[1]
    torch.ops._C.cpu_attention_with_kv_cache(
        query,
        key_cache,
        value_cache,
        output,
        query_start_loc,
        seq_lens,
        scale,
        causal,
        alibi_slopes,
        window_first,
        window_second,
        block_table,
        softcap,
        scheduler_metadata,
        s_aux,
    )
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"): if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn") @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
......
...@@ -49,7 +49,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): ...@@ -49,7 +49,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
ROCM_AITER_FA = ( ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
) )
TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" TORCH_SDPA = "" # this tag is only used for ViT
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = ( FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
...@@ -70,6 +70,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): ...@@ -70,6 +70,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.rocm_aiter_unified_attn." "vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend" "RocmAiterUnifiedAttentionBackend"
) )
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use # Placeholder for third-party/custom backends - must be registered before use
CUSTOM = "" CUSTOM = ""
......
...@@ -1726,9 +1726,6 @@ class EngineArgs: ...@@ -1726,9 +1726,6 @@ class EngineArgs:
) )
_raise_unsupported_error(feature_name=name) _raise_unsupported_error(feature_name=name)
if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
_raise_unsupported_error(feature_name="sliding window (CPU backend)")
def _set_default_args( def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig self, usage_context: UsageContext, model_config: ModelConfig
) -> None: ) -> None:
......
...@@ -8,7 +8,6 @@ import platform ...@@ -8,7 +8,6 @@ import platform
import subprocess import subprocess
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import regex as re import regex as re
...@@ -139,16 +138,15 @@ class CpuPlatform(Platform): ...@@ -139,16 +138,15 @@ class CpuPlatform(Platform):
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if use_mla:
raise NotImplementedError("MLA is not supported on CPU.") raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.") raise NotImplementedError("Sparse Attention is not supported on CPU.")
logger.info("Using Torch SDPA backend.")
if not use_v1: if not use_v1:
raise ValueError("CPU backend only supports V1.") raise ValueError("CPU backend only supports V1.")
return AttentionBackendEnum.TORCH_SDPA.get_path() return AttentionBackendEnum.CPU_ATTN.get_path()
@classmethod @classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int: def get_device_total_memory(cls, device_id: int = 0) -> int:
...@@ -186,15 +184,13 @@ class CpuPlatform(Platform): ...@@ -186,15 +184,13 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
ipex_available = find_spec("intel_extension_for_pytorch") is not None if cache_config.block_size is None:
cache_config.block_size = 128
if cache_config and cache_config.block_size is None: if cache_config.block_size % 32 != 0:
cache_config.block_size = 128 if ipex_available else 16 logger.warning(
"CPU backend prefers block_size is multiples of 32, "
if not ipex_available and cache_config.block_size != 16: "otherwise the performance is not optimized."
raise RuntimeError(
f"--block-size={cache_config.block_size} requires"
" intel_extension_for_pytorch"
) )
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
...@@ -207,22 +203,11 @@ class CpuPlatform(Platform): ...@@ -207,22 +203,11 @@ class CpuPlatform(Platform):
"backend is not compatible with FP8 KV cache." "backend is not compatible with FP8 KV cache."
) )
if cache_config.cache_dtype == "fp8_e4m3": if cache_config.cache_dtype != "auto":
cache_config.cache_dtype = "fp8_e5m2"
logger.warning(
"CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2."
)
if (
cache_config.cache_dtype != "auto"
and model_config is not None
and model_config.dtype == torch.half
):
logger.warning( logger.warning(
"FP8 KV cache on the CPU backend only does not" "CPU backend doesn't support KV cache quantization fallback to auto."
" support fp16 for now, cast to bf16."
) )
model_config.dtype = torch.bfloat16 cache_config.cache_dtype = "auto"
cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory() cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()
......
...@@ -57,7 +57,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" ...@@ -57,7 +57,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR # Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends # register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID" STR_INVALID_VAL: str = "INVALID"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional from typing import ClassVar
import numpy as np
import torch import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionLayer, AttentionLayer,
AttentionMetadata,
AttentionType, AttentionType,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
...@@ -24,44 +23,38 @@ from vllm.v1.attention.backends.utils import ( ...@@ -24,44 +23,38 @@ from vllm.v1.attention.backends.utils import (
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
_use_ipex = True
# AttributeError is to handle a bug in ipex
# https://github.com/intel/intel-extension-for-pytorch/pull/813
except (ImportError, AttributeError):
_use_ipex = False
from vllm import _custom_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86,)
class TorchSDPABackend(AttentionBackend): class CPUAttentionBackend(AttentionBackend):
accept_output_buffer: bool = False accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [ supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16, torch.float16,
torch.bfloat16, torch.bfloat16,
torch.float32, torch.float32,
] ]
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
attn_impl = _get_paged_attn_impl() return [32, 64, 96, 128, 160, 192, 224, 256]
return attn_impl.get_supported_head_sizes()
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TORCH_SDPA" return "CPU_ATTN"
@staticmethod @staticmethod
def get_impl_cls() -> type["TorchSDPABackendImpl"]: def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
return TorchSDPABackendImpl return CPUAttentionBackendImpl
@staticmethod @staticmethod
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]: def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]:
return TorchSDPAMetadataBuilderV1 return CPUAttentionMetadataBuilder
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
...@@ -71,9 +64,7 @@ class TorchSDPABackend(AttentionBackend): ...@@ -71,9 +64,7 @@ class TorchSDPABackend(AttentionBackend):
head_size: int, head_size: int,
cache_dtype_str: str = "auto", cache_dtype_str: str = "auto",
) -> tuple[int, ...]: ) -> tuple[int, ...]:
return _get_paged_attn_impl().get_kv_cache_shape( return 2, num_blocks, num_kv_heads, block_size, head_size
num_blocks, block_size, num_kv_heads, head_size
)
@staticmethod @staticmethod
def use_cascade_attention(*args, **kwargs) -> bool: def use_cascade_attention(*args, **kwargs) -> bool:
...@@ -81,264 +72,26 @@ class TorchSDPABackend(AttentionBackend): ...@@ -81,264 +72,26 @@ class TorchSDPABackend(AttentionBackend):
@dataclass @dataclass
class TorchSDPAMetadata(AttentionMetadata): class CPUAttentionMetadata:
"""Attention metadata for prefill and decode batched together.""" isa: str
num_actual_tokens: int # Number of tokens excluding padding.
# Total number of prefill requests. max_query_len: int
num_prefills: int query_start_loc: torch.Tensor
# Number of prefill tokens. max_seq_len: int
num_prefill_tokens: int seq_lens: torch.Tensor
# Number of decode tokens. Note that it is equivalent to the number of block_table: torch.Tensor
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
"""Metadata for PagedAttention.""" scheduler_metadata: torch.Tensor | None
# (batch_size,). The length of sequences (entire tokens seen so far) per causal: bool = True
# sequence.
decode_seq_lens_tensor: torch.Tensor | None
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
decode_max_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
decode_block_tables: torch.Tensor | None
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
chunked_prefill: bool
seq_lens: list[int] | None = None # For non-chunked prefill
# For chunked prefill only
max_query_len: int | None = None
prefill_max_seq_len: int | None = None
prefill_query_start_loc: torch.Tensor | None = None
prefill_seq_start_loc: torch.Tensor | None = None
prefill_block_tables: torch.Tensor | None = None
# For V1 logits index only
query_start_loc: torch.Tensor | None = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: list[int] | None = None
encoder_seq_lens_tensor: torch.Tensor | None = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: int | None = None
# Number of tokens input to encoder
num_encoder_tokens: int | None = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: torch.Tensor | None = None
cross_block_tables: torch.Tensor | None = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self.attn_bias: list[torch.Tensor] | None = None
self.encoder_attn_bias: list[torch.Tensor] | None = None
self.cross_attn_bias: list[torch.Tensor] | None = None
@property
def is_all_encoder_attn_metadata_set(self):
"""
All attention metadata required for encoder attention is set.
"""
return (
(self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None)
)
@property
def is_all_cross_attn_metadata_set(self):
"""
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
"""
return (
self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None)
)
@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
if self.num_prefill_tokens == 0:
return None
return self
@property
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
if self.num_decode_tokens == 0:
return None
return self
def get_seq_lens(
self,
attn_type: str,
):
"""
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
"""
if (
attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY
):
seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER:
seq_lens_q = self.encoder_seq_lens
seq_lens_kv = self.encoder_seq_lens
elif attn_type == AttentionType.ENCODER_DECODER:
seq_lens_q = self.seq_lens
seq_lens_kv = self.encoder_seq_lens
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
return seq_lens_q, seq_lens_kv
def get_attn_bias(
self,
attn_type: str,
) -> list[torch.Tensor] | None:
"""
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
"""
if (
attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY
):
return self.attn_bias
elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
return self.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def set_attn_bias(
self,
attn_bias: list[torch.Tensor],
attn_type: str,
) -> None:
"""
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
"""
if (
attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY
):
self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
self.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def get_seq_len_block_table_args(
self,
attn_type: str,
) -> tuple:
"""
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
"""
if (
attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY
):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return (
self.decode_seq_lens_tensor,
self.decode_max_seq_len,
self.decode_block_tables,
)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (
self.encoder_seq_lens_tensor,
self.max_encoder_seq_len,
self.cross_block_tables,
)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
# can be removed after deprecate sdpa
use_sdpa_prefill: bool = False
num_decode_tokens: int = 0
sdpa_attn_masks: list[torch.Tensor | None] | None = None
sdpa_start_loc: torch.Tensor | None = None
class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
reorder_batch_threshold: int = 1
class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]):
def __init__( def __init__(
self, self,
kv_cache_spec: AttentionSpec, kv_cache_spec: AttentionSpec,
...@@ -348,80 +101,104 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): ...@@ -348,80 +101,104 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
) -> None: ) -> None:
super().__init__(kv_cache_spec, layer_names, vllm_config, device) super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.scheduler_config = vllm_config.scheduler_config self.use_sdpa_prefill = False
self._init_reorder_batch_threshold(1, False) reorder_batch_threshold = None
if current_platform.get_cpu_architecture() not in _CPU_ARCH_PREFER_MIXED_BATCH:
# in this case, decode seqs are reordered to the front of prefill seqs
# to split decode and prefill. Then use SDPA for prefill and
# cpu_attention_with_kv_cache for decode
reorder_batch_threshold = 1
self.use_sdpa_prefill = True
self.seq_start_loc_cpu = torch.zeros( self._init_reorder_batch_threshold(reorder_batch_threshold, False)
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32, self.kv_cache_spec = kv_cache_spec
device="cpu", self.vllm_config = vllm_config
parallel_config = vllm_config.parallel_config
self.num_kv_heads = vllm_config.model_config.get_num_kv_heads(parallel_config)
self.num_heads = vllm_config.model_config.get_num_attention_heads(
parallel_config
) )
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() self.head_dim = kv_cache_spec.head_size
self.dtype = vllm_config.model_config.dtype
self.window_size = getattr(kv_cache_spec, "sliding_window", -1)
if self.window_size is None:
self.window_size = -1
self.block_size = vllm_config.cache_config.block_size
self.isa = _get_attn_isa(self.dtype, self.block_size)
def build( def build(
self, self,
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False, fast_build: bool = False,
) -> TorchSDPAMetadata: ) -> CPUAttentionMetadata:
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
seq_lens_cpu = common_attn_metadata.seq_lens_cpu query_start_loc = common_attn_metadata.query_start_loc
seq_lens_np = seq_lens_cpu.numpy() seq_lens = common_attn_metadata.seq_lens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
query_start_loc_np = query_start_loc_cpu.numpy()
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=True,
)
)
max_prefill_seq_len = (
seq_lens_np[num_decodes:num_reqs].max().item() if num_prefills > 0 else 0
)
max_decode_seq_len = (
seq_lens_np[:num_decodes].max().item() if num_prefills < num_reqs else 0
)
self.seq_start_loc_np[0] = 0
np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1 : num_reqs + 1])
slot_mapping = common_attn_metadata.slot_mapping.long()
block_table_tensor = common_attn_metadata.block_table_tensor block_table_tensor = common_attn_metadata.block_table_tensor
query_start_loc_np = query_start_loc_cpu.numpy() slot_mapping = common_attn_metadata.slot_mapping
query_start_loc_np[num_decodes : num_reqs + 1] -= num_decode_tokens causal = common_attn_metadata.causal
sdpa_start_loc = query_start_loc
num_decode_tokens = 0
if self.use_sdpa_prefill and causal:
# Decoder, need reorder and truncate
assert self.reorder_batch_threshold
(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
require_uniform=True,
)
)
num_reqs = num_decodes
sdpa_start_loc = sdpa_start_loc[num_decodes:] - num_decode_tokens
seq_lens = seq_lens[:num_decodes]
query_start_loc = query_start_loc[: num_decodes + 1]
block_table_tensor = block_table_tensor[:num_decodes]
sheduler_metadata = None
if causal:
# for decode batch, use the custom kernel
sheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
num_reqs=num_reqs,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
seq_lens=seq_lens,
dtype=self.dtype,
query_start_loc=query_start_loc,
causal=causal,
sliding_window_size=self.window_size,
isa=self.isa,
enable_kv_split=True,
)
attn_metadata = TorchSDPAMetadata( attn_metadata = CPUAttentionMetadata(
num_prefills=num_prefills, isa=self.isa,
num_prefill_tokens=num_prefill_tokens, num_actual_tokens=num_actual_tokens,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
# to ensure inference when chunked_prefill is disabled
seq_lens=seq_lens_cpu.tolist()[num_decodes:], # prefill
decode_seq_lens_tensor=seq_lens_cpu[:num_decodes], # decode
decode_max_seq_len=max_decode_seq_len, # decode
decode_block_tables=block_table_tensor[:num_decodes], # decode
chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
max_query_len=max_query_len, max_query_len=max_query_len,
prefill_max_seq_len=max_prefill_seq_len, query_start_loc=query_start_loc,
prefill_query_start_loc=query_start_loc_cpu[ max_seq_len=max_seq_len,
num_decodes : num_reqs + 1 seq_lens=seq_lens,
], # prefill block_table=block_table_tensor,
prefill_seq_start_loc=self.seq_start_loc_cpu[ slot_mapping=slot_mapping,
num_decodes : num_reqs + 1 scheduler_metadata=sheduler_metadata,
], # prefill causal=causal,
prefill_block_tables=block_table_tensor[num_decodes:num_reqs], # prefill use_sdpa_prefill=self.use_sdpa_prefill,
query_start_loc=query_start_loc_cpu[: num_reqs + 1], # for logits index num_decode_tokens=num_decode_tokens,
sdpa_start_loc=sdpa_start_loc,
) )
return attn_metadata return attn_metadata
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): class CPUAttentionBackendImpl(AttentionImpl):
def __init__( def __init__(
self, self,
num_heads: int, num_heads: int,
...@@ -434,37 +211,48 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -434,37 +211,48 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
logits_soft_cap: float | None = None, logits_soft_cap: float | None = None,
attn_type: str = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None, kv_sharing_target_layer_name: str | None = None,
sinks: torch.Tensor | None = None,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
raise NotImplementedError("KV sharing is not supported in V0.")
if logits_soft_cap is not None:
logger.warning_once(
"Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off."
)
self.paged_attn_impl = _get_paged_attn_impl()
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
if logits_soft_cap is not None and attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
logger.warning_once(
"CPU_ATTN does not support logits softcap for"
" ENCODER and ENCODER_ONLY, outputs may be slightly off"
)
if logits_soft_cap is None:
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window if sliding_window is None:
self.sliding_window = (-1, -1)
elif attn_type == AttentionType.ENCODER_ONLY:
self.sliding_window = (sliding_window - 1, sliding_window - 1)
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (
self.alibi_slopes is not None or self.sliding_window is not None
)
if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError("FP8 KV cache is unsupported in CPU_ATTN")
"Torch SDPA backend FP8 KV cache requires "
"intel_extension_for_pytorch support."
)
self.attn_type = attn_type self.attn_type = attn_type
self.sinks = sinks
if self.sinks is not None:
assert self.sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
"heads in the layer"
)
def forward( def forward(
self, self,
layer: AttentionLayer, layer: AttentionLayer,
...@@ -472,196 +260,130 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -472,196 +260,130 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: CPUAttentionMetadata | None,
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass for CPU attention backend.
Args: Args:
query: shape = [num_tokens, num_heads * head_size] query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape = kv_cache: shape =
[2, num_blocks, block_size * num_kv_heads * head_size] [2, num_blocks, num_kv_heads, block_size, head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for TorchSDPABackendImpl" " for CPUAttentionBackendImpl"
) )
# For warming-up # For warming-up
if attn_metadata is None: if attn_metadata is None:
return query return output
attn_type = self.attn_type num_actual_tokens = attn_metadata.num_actual_tokens
if attn_type == AttentionType.ENCODER and (
not attn_metadata.is_all_encoder_attn_metadata_set # Handle encoder attention differently - no KV cache needed
): if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
raise AttributeError( # For encoder attention,
"Encoder attention requires setting encoder metadata attributes." return self._run_sdpa_forward(
) query[:num_actual_tokens],
elif attn_type == AttentionType.ENCODER_DECODER and ( key[:num_actual_tokens],
not attn_metadata.is_all_cross_attn_metadata_set value[:num_actual_tokens],
): output[:num_actual_tokens],
raise AttributeError( attn_metadata,
"Encoder/decoder cross-attention " self.attn_type,
"requires setting cross-attention "
"metadata attributes."
) )
# Reshape the query, key, and value tensors. # For decoder and cross-attention, use KV cache, size are
query = query.view(-1, self.num_heads, self.head_size) # [num_blocks, num_kv_heads, block_size, head_size]
if key is not None: key_cache, value_cache = kv_cache.unbind(0)
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = self.paged_attn_impl.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
if (key is not None) and (value is not None): # key and value may be None in the case of cross attention. They are
if attn_type == AttentionType.ENCODER_DECODER: # calculated once based on the output from the encoder and then cached
# Update cross-attention KV cache (prefill-only) # in KV cache.
# During cross-attention decode, key & value will be None, if (
# preventing this IF-statement branch from running self.kv_sharing_target_layer_name is None
updated_slot_mapping = attn_metadata.cross_slot_mapping and key is not None
else: and value is not None
# Update self-attention KV cache (prefill/decode) ):
updated_slot_mapping = attn_metadata.slot_mapping ops.cpu_attn_reshape_and_cache(
key,
self.paged_attn_impl.write_to_paged_cache( value,
key, key_cache,
value, value_cache,
key_cache, attn_metadata.slot_mapping,
value_cache, attn_metadata.isa,
updated_slot_mapping, )
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if attn_type != AttentionType.ENCODER: if attn_metadata.use_sdpa_prefill:
# Decoder self-attention supports chunked prefill. assert self.sinks is None, "Attention sink is unsupported in SDPA prefill"
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
else: self._run_sdpa_forward(
# Encoder attention - chunked prefill is not applicable; query[num_decode_tokens:num_actual_tokens],
# derive token-count from query shape & and treat them key[num_decode_tokens:num_actual_tokens],
# as 100% prefill tokens value[num_decode_tokens:num_actual_tokens],
assert attn_metadata.num_encoder_tokens is not None output[num_decode_tokens:num_actual_tokens],
num_prefill_tokens = attn_metadata.num_encoder_tokens attn_metadata,
num_decode_tokens = 0 self.attn_type,
if attn_type == AttentionType.DECODER:
# Only enforce this shape-constraint for decoder
# self-attention
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore
assert attn_metadata.seq_lens is not None
self._run_sdpa_forward(
output, query, key, value, prefill_meta, attn_type=attn_type
)
else:
# prefix-enabled attention
assert not self.need_mask
import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[prefill_meta.num_decode_tokens :, :, :],
query[prefill_meta.num_decode_tokens :, :, :],
key_cache,
value_cache,
prefill_meta.prefill_query_start_loc,
prefill_meta.prefill_seq_start_loc,
prefill_meta.max_query_len,
prefill_meta.prefill_max_seq_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata."
) )
# Decoding run. num_actual_tokens = num_decode_tokens
(
seq_lens_arg, if num_actual_tokens > 0:
max_seq_len_arg, ops.cpu_attention_with_kv_cache(
block_tables_arg, query=query[:num_actual_tokens],
) = decode_meta.get_seq_len_block_table_args(attn_type) key_cache=key_cache,
value_cache=value_cache,
self.paged_attn_impl.forward_decode( output=output[:num_actual_tokens], # type: ignore
output[: attn_metadata.num_decode_tokens, :, :], query_start_loc=attn_metadata.query_start_loc,
query[: attn_metadata.num_decode_tokens, :, :], seq_lens=attn_metadata.seq_lens,
key_cache, scale=self.scale,
value_cache, causal=attn_metadata.causal,
block_tables_arg, alibi_slopes=self.alibi_slopes, # type: ignore
seq_lens_arg, sliding_window=self.sliding_window,
max_seq_len_arg, block_table=attn_metadata.block_table,
self.kv_cache_dtype, softcap=self.logits_soft_cap,
self.num_kv_heads, scheduler_metadata=attn_metadata.scheduler_metadata,
self.scale, s_aux=self.sinks,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
) )
# Reshape the output tensor. return output
return output.view(-1, self.num_heads * self.head_size)
def _run_sdpa_forward( def _run_sdpa_forward(
self, self,
output: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: TorchSDPAMetadata, output: torch.Tensor,
attn_type: str = AttentionType.DECODER, attn_metadata: CPUAttentionMetadata,
) -> None: attn_type: str,
attn_masks = attn_metadata.get_attn_bias(attn_type) ) -> torch.Tensor:
attn_masks = attn_metadata.sdpa_attn_masks
if attn_masks is None: if attn_masks is None:
if self.alibi_slopes is not None: if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias( attn_masks = _make_alibi_bias(
self.alibi_slopes, self.alibi_slopes,
query.dtype, query.dtype,
attn_metadata.seq_lens, # type: ignore attn_metadata.sdpa_start_loc,
) )
elif self.sliding_window is not None: elif self.sliding_window[0] != -1 or self.sliding_window[1] != -1:
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
attn_masks = _make_sliding_window_bias( attn_masks = _make_sliding_window_bias(
attn_metadata.seq_lens, self.sliding_window, query.dtype attn_metadata.sdpa_start_loc,
self.sliding_window[0],
self.sliding_window[1],
query.dtype,
) )
else: else:
seq_lens, _ = attn_metadata.get_seq_lens(attn_type) attn_masks = [None] * (attn_metadata.sdpa_start_loc.size(0) - 1) # type: ignore
attn_masks = [None] * len(seq_lens) attn_metadata.sdpa_attn_masks = attn_masks
attn_metadata.set_attn_bias(attn_masks, attn_type)
query = query.movedim(0, query.dim() - 2) query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2) key = key.movedim(0, key.dim() - 2)
...@@ -673,21 +395,16 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -673,21 +395,16 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
causal_attn = attn_type == AttentionType.DECODER causal_attn = attn_type == AttentionType.DECODER
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy() # type: ignore
# Incoming Q and KV contain decoded tokens as well, hence start at an offset for i in range(len(attn_masks)):
# equal to num_decode_tokens since decode requests appear first mask = attn_masks[i]
start_q, start_kv = ( start_q = sdpa_start_loc[i]
attn_metadata.num_decode_tokens, end_q = sdpa_start_loc[i + 1]
attn_metadata.num_decode_tokens,
)
for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks):
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
sub_out = ( sub_out = (
scaled_dot_product_attention( torch.nn.functional.scaled_dot_product_attention(
query[None, :, start_q:end_q, :], query[None, :, start_q:end_q, :],
key[None, :, start_kv:end_kv, :], key[None, :, start_q:end_q, :],
value[None, :, start_kv:end_kv, :], value[None, :, start_q:end_q, :],
attn_mask=mask, attn_mask=mask,
dropout_p=0.0, dropout_p=0.0,
is_causal=causal_attn and mask is None, is_causal=causal_attn and mask is None,
...@@ -697,17 +414,20 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -697,17 +414,20 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
.movedim(query.dim() - 2, 0) .movedim(query.dim() - 2, 0)
) )
output[start_q:end_q, :, :] = sub_out output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv return output
def _make_alibi_bias( def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
dtype: torch.dtype, dtype: torch.dtype,
seq_lens: list[int], sdpa_start_loc: torch.Tensor,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = [] attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens: seq_num = sdpa_start_loc.size(0) - 1
bias = torch.arange(seq_len, dtype=dtype) sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
for i in range(seq_num):
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
bias = torch.arange(seq_len, dtype=dtype) # type: ignore
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)` # `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but # here. We find that both biases give the same results, but
...@@ -719,7 +439,7 @@ def _make_alibi_bias( ...@@ -719,7 +439,7 @@ def _make_alibi_bias(
bias = bias[None, :].repeat((num_heads, 1, 1)) bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = ( inf_mask = (
torch.empty((1, seq_len, seq_len), dtype=bias.dtype) torch.empty((1, seq_len, seq_len), dtype=bias.dtype) # type: ignore
.fill_(-torch.inf) .fill_(-torch.inf)
.triu_(diagonal=1) .triu_(diagonal=1)
) )
...@@ -729,210 +449,37 @@ def _make_alibi_bias( ...@@ -729,210 +449,37 @@ def _make_alibi_bias(
def _make_sliding_window_bias( def _make_sliding_window_bias(
seq_lens: list[int], sdpa_start_loc: torch.Tensor,
window_size: int | None, left_window_size: int,
right_window_size: int,
dtype: torch.dtype, dtype: torch.dtype,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
attn_biases: list[torch.Tensor] = [] attn_biases: list[torch.Tensor] = []
for seq_len in seq_lens: seq_num = sdpa_start_loc.size(0) - 1
tensor = torch.full( sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
(1, seq_len, seq_len), for i in range(seq_num):
dtype=dtype, seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
mask = torch.full( # type: ignore
(1, seq_len, seq_len), # type: ignore
fill_value=1, fill_value=1,
dtype=dtype,
) )
shift = 0
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore if right_window_size != -1:
if window_size is not None: mask = torch.tril(mask, diagonal=right_window_size)
mask = torch.triu(mask, diagonal=shift - window_size + 1) if left_window_size != -1:
mask = torch.triu(mask, diagonal=-left_window_size)
mask = torch.log(mask) mask = torch.log(mask)
attn_biases.append(mask.to(dtype)) attn_biases.append(mask)
return attn_biases return attn_biases
class _PagedAttention: def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
@staticmethod supports_amx = torch._C._cpu._is_amx_tile_supported()
def get_supported_head_sizes() -> list[int]: if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return [32, 64, 80, 96, 112, 128, 192, 256] return "amx"
elif block_size % 32 == 0:
@staticmethod return "vec"
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
*args,
) -> tuple[int, ...]:
return 2, num_blocks, block_size * num_kv_heads * head_size
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    *args,
) -> None:
    """Forward the key/value tensors to ``ops.reshape_and_cache``.

    All named arguments are passed through unchanged, except that
    ``slot_mapping`` is flattened first. Any extra positional arguments are
    ignored. Nothing is returned.
    """
    flat_slots = slot_mapping.flatten()
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        flat_slots,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
@staticmethod
def forward_decode(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    max_context_len: int,
    kv_cache_dtype: str,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: torch.Tensor | None,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    *args,
) -> None:
    """Invoke ``ops.paged_attention_v1`` for a decode step.

    The block size is read from ``value_cache.shape[3]``. The tensor-parallel
    rank and the block-sparse settings are fixed constants rather than
    caller-supplied values. Any extra positional arguments are ignored.

    Nothing is returned; presumably the kernel writes its result into
    ``output`` (confirm against ``ops.paged_attention_v1``).
    """
    block_size = value_cache.shape[3]

    # Fixed trailing kernel arguments, in call order:
    #   tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
    #   blocksparse_block_size, blocksparse_head_sliding_step
    fixed_tail = (0, 0, 0, 64, 0)

    ops.paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        *fixed_tail,
    )
class _IPEXPagedAttention(_PagedAttention):
    """Paged-attention variant that dispatches to ``ipex_modules.PagedAttention``.

    Overrides the cache layout, the cache write and the decode call of
    ``_PagedAttention``.
    """

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return an empty list of head sizes."""
        return []

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Split a combined cache tensor into key and value views.

        Both ``kv_cache[0]`` and ``kv_cache[1]`` are viewed as
        ``(num_blocks, num_kv_heads, -1, head_size)``, with ``num_blocks``
        read from ``kv_cache.shape[1]``. Extra positional arguments are
        ignored.
        """
        block_count = kv_cache.shape[1]
        layout = (block_count, num_kv_heads, -1, head_size)
        key_cache = kv_cache[0].view(*layout)
        value_cache = kv_cache[1].view(*layout)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Forward key/value tensors to the IPEX ``reshape_and_cache`` call.

        ``slot_mapping`` is flattened and converted with ``.int()`` before the
        call. ``kv_cache_dtype``, ``k_scale`` and ``v_scale`` are accepted but
        not passed on.
        """
        int_slots = slot_mapping.flatten().int()
        ipex_modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, int_slots
        )

    @staticmethod
    def forward_decode(
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: torch.Tensor | None,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        *args,
    ) -> None:
        """Invoke the IPEX ``single_query_cached_kv_attention`` call.

        The block size is read from ``value_cache.shape[2]``. A head mapping
        is built in which each KV head index ``0..num_kv_heads-1`` is repeated
        ``query.size(1) // num_kv_heads`` times. ``kv_cache_dtype``,
        ``k_scale`` and ``v_scale`` are accepted but not passed on.
        """
        block_size = value_cache.shape[2]

        # One int32 entry per repeated KV head index, built on the CPU.
        kv_head_ids = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        )
        repeats_per_kv_head = query.size(1) // num_kv_heads
        head_mapping = (
            kv_head_ids.view(num_kv_heads, 1)
            .repeat_interleave(repeats_per_kv_head)
            .flatten()
        )

        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            output,
            query.contiguous(),
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
def _get_paged_attn_impl():
if _use_ipex:
return _IPEXPagedAttention
else: else:
return _PagedAttention return "vec16"
...@@ -265,7 +265,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -265,7 +265,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
def _init_reorder_batch_threshold( def _init_reorder_batch_threshold(
self, self,
reorder_batch_threshold: int = 1, reorder_batch_threshold: int | None = 1,
supports_spec_as_decode: bool = False, supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False, supports_dcp_with_varlen: bool = False,
) -> None: ) -> None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -12,9 +12,6 @@ from vllm.model_executor.model_loader import get_model ...@@ -12,9 +12,6 @@ from vllm.model_executor.model_loader import get_model
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,15 +28,6 @@ class CPUModelRunner(GPUModelRunner): ...@@ -31,15 +28,6 @@ class CPUModelRunner(GPUModelRunner):
self._postprocess_tensors() self._postprocess_tensors()
# Note: Remove the override after new attention backend finished
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
if len(self.kv_cache_config.kv_cache_groups) > 1:
raise ValueError(
"Multiple KVCacheGroups is not"
"currently supported with CPU model runner."
)
super()._may_reorder_batch(scheduler_output)
def _postprocess_tensors(self) -> None: def _postprocess_tensors(self) -> None:
# Note: replace device tensors with cpu tensors # Note: replace device tensors with cpu tensors
def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None: def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment