Unverified commit 7f829be7, authored by Li, Jiang; committed by GitHub
Browse files

[CPU] Refactor CPU attention backend (#27954)


Signed-off-by: jiang1.li <jiang1.li@intel.com>
parent e1710393
...@@ -35,7 +35,7 @@ DEVICE_MLA_BACKENDS = { ...@@ -35,7 +35,7 @@ DEVICE_MLA_BACKENDS = {
DEVICE_REGULAR_ATTN_BACKENDS = { DEVICE_REGULAR_ATTN_BACKENDS = {
"cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"], "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
"hip": ["ROCM_ATTN"], "hip": ["ROCM_ATTN"],
"cpu": ["TORCH_SDPA"], "cpu": ["CPU_ATTN"],
} }
DEVICE_MLA_BLOCK_SIZES = { DEVICE_MLA_BLOCK_SIZES = {
...@@ -86,7 +86,7 @@ def test_env( ...@@ -86,7 +86,7 @@ def test_env(
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size) backend = get_attn_backend(16, torch.float16, None, block_size)
assert backend.get_name() == "TORCH_SDPA" assert backend.get_name() == "CPU_ATTN"
elif device == "hip": elif device == "hip":
with patch("vllm.platforms.current_platform", RocmPlatform()): with patch("vllm.platforms.current_platform", RocmPlatform()):
...@@ -224,7 +224,7 @@ def test_fp32_fallback(device: str): ...@@ -224,7 +224,7 @@ def test_fp32_fallback(device: str):
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16) backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "TORCH_SDPA" assert backend.get_name() == "CPU_ATTN"
elif device == "cuda": elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()): with patch("vllm.platforms.current_platform", CudaPlatform()):
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""
import pytest import pytest
import torch import torch
......
...@@ -38,7 +38,11 @@ AITER_MODEL_LIST = [ ...@@ -38,7 +38,11 @@ AITER_MODEL_LIST = [
[ [
pytest.param( pytest.param(
"bigscience/bloom-560m", # bloom - testing alibi slopes "bigscience/bloom-560m", # bloom - testing alibi slopes
marks=[pytest.mark.core_model, pytest.mark.slow_test], marks=[
pytest.mark.core_model,
pytest.mark.slow_test,
pytest.mark.cpu_model,
],
), ),
pytest.param( pytest.param(
"openai-community/gpt2", # gpt2 "openai-community/gpt2", # gpt2
...@@ -55,6 +59,10 @@ AITER_MODEL_LIST = [ ...@@ -55,6 +59,10 @@ AITER_MODEL_LIST = [
pytest.mark.slow_test, pytest.mark.slow_test,
], ],
), ),
pytest.param(
"google/gemma-2-2b-it", # test hybrid attention
marks=[pytest.mark.cpu_model],
),
pytest.param( pytest.param(
"zai-org/chatglm3-6b", # chatglm (text-only) "zai-org/chatglm3-6b", # chatglm (text-only)
), ),
...@@ -64,7 +72,6 @@ AITER_MODEL_LIST = [ ...@@ -64,7 +72,6 @@ AITER_MODEL_LIST = [
), ),
pytest.param( pytest.param(
"openbmb/MiniCPM3-4B", "openbmb/MiniCPM3-4B",
# fused_moe not supported on CPU
marks=[pytest.mark.core_model, large_gpu_mark(min_gb=32)], marks=[pytest.mark.core_model, large_gpu_mark(min_gb=32)],
), ),
pytest.param( pytest.param(
...@@ -93,11 +100,7 @@ AITER_MODEL_LIST = [ ...@@ -93,11 +100,7 @@ AITER_MODEL_LIST = [
pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param("bigcode/starcoder2-3b"), # starcoder2
pytest.param( pytest.param(
"TitanML/tiny-mixtral", # mixtral "TitanML/tiny-mixtral", # mixtral
marks=[pytest.mark.core_model], marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
pytest.param(
"allenai/OLMoE-1B-7B-0924-Instruct",
marks=[pytest.mark.cpu_model],
), ),
pytest.param("swiss-ai/Apertus-8B-Instruct-2509"), # apertus pytest.param("swiss-ai/Apertus-8B-Instruct-2509"), # apertus
], ],
......
...@@ -23,8 +23,7 @@ from ...utils import check_embeddings_close ...@@ -23,8 +23,7 @@ from ...utils import check_embeddings_close
), ),
pytest.param( pytest.param(
"intfloat/e5-mistral-7b-instruct", "intfloat/e5-mistral-7b-instruct",
# CPU v1 doesn't support sliding window marks=[pytest.mark.core_model, pytest.mark.cpu_model],
marks=[pytest.mark.core_model],
), ),
pytest.param( pytest.param(
"ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model] "ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model]
......
...@@ -243,7 +243,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -243,7 +243,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"),
"FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"), "FlexOlmoForCausalLM": _HfExamplesInfo("allenai/Flex-reddit-2x7B-1T"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma2ForCausalLM": _HfExamplesInfo(
"google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
),
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"), "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"), "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
......
...@@ -2583,6 +2583,88 @@ def onednn_scaled_mm( ...@@ -2583,6 +2583,88 @@ def onednn_scaled_mm(
return output return output
def cpu_attn_get_scheduler_metadata(
    num_reqs: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    seq_lens: torch.Tensor,
    dtype: torch.dtype,
    query_start_loc: torch.Tensor,
    causal: bool,
    sliding_window_size: int,
    isa: str,
    enable_kv_split: bool,
) -> torch.Tensor:
    """Call the ``_C.get_scheduler_metadata`` custom op and return its result.

    This is a thin Python entry point: every argument is forwarded
    positionally, unchanged and in declaration order, and whatever the op
    returns is handed straight back to the caller.

    NOTE(review): the meaning and layout of the returned tensor are defined
    by the compiled kernel, not by this wrapper -- confirm against the op
    implementation before relying on its contents.
    """
    get_metadata = torch.ops._C.get_scheduler_metadata
    return get_metadata(
        num_reqs, num_heads, num_kv_heads, head_dim,
        seq_lens, dtype, query_start_loc,
        causal, sliding_window_size,
        isa, enable_kv_split,
    )
def cpu_attn_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    isa: str,
) -> None:
    """Invoke the ``_C.cpu_attn_reshape_and_cache`` custom op.

    All six arguments are forwarded positionally and unchanged; the op's
    return value (if any) is discarded and this function returns ``None``.

    NOTE(review): presumably the op writes ``key``/``value`` into
    ``key_cache``/``value_cache`` in place at the positions given by
    ``slot_mapping`` -- confirm against the kernel implementation.
    """
    reshape_and_cache = torch.ops._C.cpu_attn_reshape_and_cache
    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, isa)
def cpu_attention_with_kv_cache(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    output: torch.Tensor,
    query_start_loc: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
    causal: bool,
    alibi_slopes: torch.Tensor | None,
    sliding_window: tuple[int, int],
    block_table: torch.Tensor,
    softcap: float,
    scheduler_metadata: torch.Tensor,
    s_aux: torch.Tensor | None,
) -> None:
    """Invoke the ``_C.cpu_attention_with_kv_cache`` custom op.

    Arguments are forwarded positionally in declaration order, with one
    adaptation: ``sliding_window`` is split into its first and second
    elements, which the op receives as two consecutive arguments between
    ``alibi_slopes`` and ``block_table``.

    Returns ``None``; the op's return value (if any) is discarded.

    NOTE(review): presumably the attention result is written into
    ``output`` in place -- confirm against the kernel implementation.
    """
    # The op takes the two window bounds as separate positional arguments.
    window_first = sliding_window[0]
    window_second = sliding_window[1]

    attention_op = torch.ops._C.cpu_attention_with_kv_cache
    attention_op(
        query, key_cache, value_cache, output,
        query_start_loc, seq_lens,
        scale, causal, alibi_slopes,
        window_first, window_second,
        block_table, softcap,
        scheduler_metadata, s_aux,
    )
if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"): if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
@register_fake("_qutlass_C::matmul_mxf4_bf16_tn") @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
......
...@@ -49,7 +49,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): ...@@ -49,7 +49,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
ROCM_AITER_FA = ( ROCM_AITER_FA = (
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
) )
TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" TORCH_SDPA = "" # this tag is only used for ViT
FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" FLASHINFER = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA = ( FLASHINFER_MLA = (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
...@@ -70,6 +70,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta): ...@@ -70,6 +70,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.rocm_aiter_unified_attn." "vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend" "RocmAiterUnifiedAttentionBackend"
) )
CPU_ATTN = "vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use # Placeholder for third-party/custom backends - must be registered before use
CUSTOM = "" CUSTOM = ""
......
...@@ -1726,9 +1726,6 @@ class EngineArgs: ...@@ -1726,9 +1726,6 @@ class EngineArgs:
) )
_raise_unsupported_error(feature_name=name) _raise_unsupported_error(feature_name=name)
if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
_raise_unsupported_error(feature_name="sliding window (CPU backend)")
def _set_default_args( def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig self, usage_context: UsageContext, model_config: ModelConfig
) -> None: ) -> None:
......
...@@ -8,7 +8,6 @@ import platform ...@@ -8,7 +8,6 @@ import platform
import subprocess import subprocess
import sys import sys
from dataclasses import dataclass from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import regex as re import regex as re
...@@ -139,16 +138,15 @@ class CpuPlatform(Platform): ...@@ -139,16 +138,15 @@ class CpuPlatform(Platform):
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if use_mla:
raise NotImplementedError("MLA is not supported on CPU.") raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.") raise NotImplementedError("Sparse Attention is not supported on CPU.")
logger.info("Using Torch SDPA backend.")
if not use_v1: if not use_v1:
raise ValueError("CPU backend only supports V1.") raise ValueError("CPU backend only supports V1.")
return AttentionBackendEnum.TORCH_SDPA.get_path() return AttentionBackendEnum.CPU_ATTN.get_path()
@classmethod @classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int: def get_device_total_memory(cls, device_id: int = 0) -> int:
...@@ -186,15 +184,13 @@ class CpuPlatform(Platform): ...@@ -186,15 +184,13 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
ipex_available = find_spec("intel_extension_for_pytorch") is not None if cache_config.block_size is None:
cache_config.block_size = 128
if cache_config and cache_config.block_size is None: if cache_config.block_size % 32 != 0:
cache_config.block_size = 128 if ipex_available else 16 logger.warning(
"CPU backend prefers block_size is multiples of 32, "
if not ipex_available and cache_config.block_size != 16: "otherwise the performance is not optimized."
raise RuntimeError(
f"--block-size={cache_config.block_size} requires"
" intel_extension_for_pytorch"
) )
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
...@@ -207,22 +203,11 @@ class CpuPlatform(Platform): ...@@ -207,22 +203,11 @@ class CpuPlatform(Platform):
"backend is not compatible with FP8 KV cache." "backend is not compatible with FP8 KV cache."
) )
if cache_config.cache_dtype == "fp8_e4m3": if cache_config.cache_dtype != "auto":
cache_config.cache_dtype = "fp8_e5m2"
logger.warning(
"CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2."
)
if (
cache_config.cache_dtype != "auto"
and model_config is not None
and model_config.dtype == torch.half
):
logger.warning( logger.warning(
"FP8 KV cache on the CPU backend only does not" "CPU backend doesn't support KV cache quantization fallback to auto."
" support fp16 for now, cast to bf16."
) )
model_config.dtype = torch.bfloat16 cache_config.cache_dtype = "auto"
cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory() cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory()
......
...@@ -57,7 +57,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" ...@@ -57,7 +57,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR # Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends # register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID" STR_INVALID_VAL: str = "INVALID"
......
This diff is collapsed.
...@@ -265,7 +265,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -265,7 +265,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
def _init_reorder_batch_threshold( def _init_reorder_batch_threshold(
self, self,
reorder_batch_threshold: int = 1, reorder_batch_threshold: int | None = 1,
supports_spec_as_decode: bool = False, supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False, supports_dcp_with_varlen: bool = False,
) -> None: ) -> None:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -12,9 +12,6 @@ from vllm.model_executor.model_loader import get_model ...@@ -12,9 +12,6 @@ from vllm.model_executor.model_loader import get_model
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,15 +28,6 @@ class CPUModelRunner(GPUModelRunner): ...@@ -31,15 +28,6 @@ class CPUModelRunner(GPUModelRunner):
self._postprocess_tensors() self._postprocess_tensors()
# Note: Remove the override after new attention backend finished
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
if len(self.kv_cache_config.kv_cache_groups) > 1:
raise ValueError(
"Multiple KVCacheGroups is not"
"currently supported with CPU model runner."
)
super()._may_reorder_batch(scheduler_output)
def _postprocess_tensors(self) -> None: def _postprocess_tensors(self) -> None:
# Note: replace device tensors with cpu tensors # Note: replace device tensors with cpu tensors
def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None: def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment