Unverified Commit 63f1fde2 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[Hardware][CPU] Support chunked-prefill and prefix-caching on CPU (#10355)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent d5b28447
......@@ -25,6 +25,7 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
function cpu_tests() {
set -e
export NUMA_NODE=$2
# offline inference
docker exec cpu-test-avx2-"$NUMA_NODE" bash -c "
......@@ -57,6 +58,12 @@ function cpu_tests() {
pytest -s -v \
tests/quantization/test_ipex_quant.py"
# Run chunked-prefill and prefix-cache test
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
pytest -s -v -k cpu_model \
tests/basic_correctness/test_chunked_prefill.py"
# online inference
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
......@@ -75,4 +82,4 @@ function cpu_tests() {
# All of CPU tests are expected to be finished less than 25 mins.
export -f cpu_tests
timeout 25m bash -c "cpu_tests $CORE_RANGE"
timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
......@@ -5,11 +5,11 @@ Installation with CPU
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:
- Tensor Parallel (``-tp = N``)
- Quantization (``INT8 W8A8, AWQ``)
.. note::
More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon.
- Tensor Parallel
- Model Quantization (``INT8 W8A8, AWQ``)
- Chunked-prefill
- Prefix-caching
- FP8-E5M2 KV-Caching (TODO)
Table of contents:
......
......@@ -344,7 +344,7 @@ Feature x Hardware
- ✅
- ✅
- ✅
-
-
- ✅
* - :ref:`APC <apc>`
- `✗ <https://github.com/vllm-project/vllm/issues/3687>`__
......@@ -352,7 +352,7 @@ Feature x Hardware
- ✅
- ✅
- ✅
-
-
- ✅
* - :ref:`LoRA <lora>`
- ✅
......
......@@ -12,6 +12,7 @@ from contextlib import nullcontext
import pytest
from tests.kernels.utils import override_backend_env_variable
from vllm.platforms import current_platform
from ..models.utils import check_logprobs_close, check_outputs_equal
from ..utils import multi_gpu_test
......@@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("dtype", ["half"])
def test_with_prefix_caching(
vllm_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
tensor_parallel_size: int,
dtype: str,
) -> None:
"""
Checks exact match decode with and without prefix caching
......@@ -233,7 +236,7 @@ def test_with_prefix_caching(
for enable in (True, False):
with vllm_runner(
model,
dtype="half",
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=True,
enable_prefix_caching=enable,
......@@ -260,3 +263,61 @@ def test_with_prefix_caching(
name_0="w/o prefix caching",
name_1="with prefix caching",
)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_models_cpu(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
enforce_eager: bool,
attention_backend: str,
monkeypatch,
) -> None:
test_models(
hf_runner,
vllm_runner,
example_prompts,
model,
dtype,
max_tokens,
chunked_prefill_token_size,
enforce_eager,
1,
attention_backend,
monkeypatch,
)
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_with_prefix_caching_cpu(
vllm_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
dtype: str,
) -> None:
test_with_prefix_caching(
vllm_runner,
max_tokens,
enforce_eager,
chunk_size,
1,
dtype,
)
......@@ -7,18 +7,14 @@ import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.platforms import current_platform
if current_platform.is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
else:
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
class TorchSDPABackend(AttentionBackend):
......@@ -39,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
return TorchSDPAMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......@@ -71,9 +71,15 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
seq_lens: Optional[List[int]]
chunked_prefill: bool
seq_lens: Optional[List[int]] = None # For non-chunked prefill
# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
......@@ -123,20 +129,14 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self
if self.num_prefill_tokens == 0:
return None
return self
@property
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
if self.num_decode_tokens == 0:
return None
return self
def get_seq_lens(
......@@ -274,6 +274,105 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
self.chunked_prefill = input_builder.chunked_prefill
self.input_data = input_builder.input_data
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
input_data = self.input_data
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
prefill_query_lens = query_lens[0:input_data.num_prefills]
slot_mapping = torch.tensor(input_data.slot_mapping,
dtype=torch.long,
device="cpu")
# For chunked-prefill
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
prefill_block_tables = make_tensor_with_pad(
self.input_data.prefill_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
query_lens_tensor = torch.tensor(prefill_query_lens,
dtype=torch.int32,
device="cpu")
kv_lens_tensor = torch.tensor(prefill_seq_lens,
dtype=torch.int32,
device="cpu")
query_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
torch.cumsum(query_lens_tensor,
dim=0,
dtype=torch.int32,
out=query_start_loc[1:])
torch.cumsum(kv_lens_tensor,
dim=0,
dtype=torch.int32,
out=kv_start_loc[1:])
max_query_len = max(prefill_query_lens)
max_kv_len = max(prefill_seq_lens)
else:
prefill_block_tables = None
query_start_loc = None
kv_start_loc = None
max_query_len = None
max_kv_len = None
# For paged attention
if input_data.num_decode_tokens != 0:
seq_lens_tensor = torch.tensor(
input_data.seq_lens[input_data.num_prefills:],
dtype=torch.int32,
device="cpu",
)
block_tables = make_tensor_with_pad(
self.input_data.decode_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
else:
block_tables = torch.tensor([])
seq_lens_tensor = torch.tensor([])
# For multi-modal models
placeholder_index_maps = None
if len(input_data.multi_modal_inputs_list) != 0:
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
input_data.multi_modal_placeholder_maps.items()
}
attn_metadata = TorchSDPAMetadata(
chunked_prefill=self.chunked_prefill,
seq_lens=prefill_seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_kv_len=max_kv_len,
query_start_loc=query_start_loc,
kv_start_loc=kv_start_loc,
max_decode_seq_len=input_data.max_decode_seq_len,
num_prefills=input_data.num_prefills,
num_prefill_tokens=input_data.num_prefill_tokens,
num_decode_tokens=input_data.num_decode_tokens,
block_tables=block_tables,
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
)
return attn_metadata
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def __init__(
......@@ -409,19 +508,35 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0
or prefill_meta.block_tables.numel() == 0):
output = self._run_sdpa_forward(query,
if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore
self._run_sdpa_forward(output,
query,
key,
value,
prefill_meta,
attn_type=attn_type)
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
assert not self.need_mask
import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[:prefill_meta.num_prefill_tokens, :, :],
query[:prefill_meta.num_prefill_tokens, :, :],
key_cache,
value_cache,
prefill_meta.query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
......@@ -433,8 +548,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
block_tables_arg,
) = decode_meta.get_seq_len_block_table_args(attn_type)
output = PagedAttention.forward_decode(
query,
PagedAttention.forward_decode(
output[attn_metadata.num_prefill_tokens:, :, :],
query[attn_metadata.num_prefill_tokens:, :, :],
key_cache,
value_cache,
block_tables_arg,
......@@ -453,12 +569,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def _run_sdpa_forward(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER,
):
) -> None:
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
......@@ -479,7 +596,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type)
output = torch.empty_like(query)
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
......@@ -502,7 +618,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv
return output
def _make_alibi_bias(
......
from typing import Dict, List, Optional, Tuple
import intel_extension_for_pytorch.llm.modules as ipex_modules
try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
_use_ipex = True
except ImportError:
_use_ipex = False
import torch
from vllm import _custom_ops as ops
class PagedAttention:
class _PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
......@@ -22,6 +27,105 @@ class PagedAttention:
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
*args,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
*args,
) -> None:
tp_rank: int = 0
blocksparse_local_blocks: int = 0
blocksparse_vert_stride: int = 0
blocksparse_block_size: int = 64
blocksparse_head_sliding_step: int = 0
block_size = value_cache.shape[3]
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
class _IPEXPagedAttention(_PagedAttention):
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
......@@ -55,6 +159,7 @@ class PagedAttention:
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
......@@ -68,8 +173,7 @@ class PagedAttention:
k_scale: float,
v_scale: float,
*args,
) -> torch.Tensor:
output = torch.empty_like(query)
) -> None:
block_size = value_cache.shape[2]
head_mapping = torch.arange(
0,
......@@ -83,41 +187,5 @@ class PagedAttention:
scale, block_tables, context_lens, block_size, max_context_len,
alibi_slopes)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
alibi_slopes: Optional[torch.Tensor],
*args,
) -> torch.Tensor:
raise NotImplementedError
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
*args,
) -> None:
raise NotImplementedError
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention
......@@ -53,11 +53,6 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config
if cache_config.enable_prefix_caching:
logger.warning(
"Prefix caching is not supported on CPU, disable it.")
cache_config.enable_prefix_caching = False
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
if kv_cache_space >= 0:
......@@ -74,10 +69,12 @@ class CpuPlatform(Platform):
f" {kv_cache_space}, expect a positive integer value.")
scheduler_config = vllm_config.scheduler_config
if scheduler_config.chunked_prefill_enabled:
logger.warning(
"Chunked prefill is not supported on CPU, disable it.")
scheduler_config.chunked_prefill_enabled = False
if ((scheduler_config.chunked_prefill_enabled
or cache_config.enable_prefix_caching)
and model_config.dtype == torch.half):
logger.warning("Chunked-prefill on the CPU backend only does not"
" support fp16 for now, cast to bf16.")
model_config.dtype = torch.bfloat16
parallel_config = vllm_config.parallel_config
if (parallel_config.distributed_executor_backend is not None
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment