Unverified Commit 63f1fde2 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[Hardware][CPU] Support chunked-prefill and prefix-caching on CPU (#10355)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent d5b28447
...@@ -25,6 +25,7 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg ...@@ -25,6 +25,7 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
function cpu_tests() { function cpu_tests() {
set -e set -e
export NUMA_NODE=$2
# offline inference # offline inference
docker exec cpu-test-avx2-"$NUMA_NODE" bash -c " docker exec cpu-test-avx2-"$NUMA_NODE" bash -c "
...@@ -57,6 +58,12 @@ function cpu_tests() { ...@@ -57,6 +58,12 @@ function cpu_tests() {
pytest -s -v \ pytest -s -v \
tests/quantization/test_ipex_quant.py" tests/quantization/test_ipex_quant.py"
# Run chunked-prefill and prefix-cache test
docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e
pytest -s -v -k cpu_model \
tests/basic_correctness/test_chunked_prefill.py"
# online inference # online inference
docker exec cpu-test-"$NUMA_NODE" bash -c " docker exec cpu-test-"$NUMA_NODE" bash -c "
set -e set -e
...@@ -75,4 +82,4 @@ function cpu_tests() { ...@@ -75,4 +82,4 @@ function cpu_tests() {
# All of CPU tests are expected to be finished less than 25 mins. # All of CPU tests are expected to be finished less than 25 mins.
export -f cpu_tests export -f cpu_tests
timeout 25m bash -c "cpu_tests $CORE_RANGE" timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
...@@ -5,11 +5,11 @@ Installation with CPU ...@@ -5,11 +5,11 @@ Installation with CPU
vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features: vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features:
- Tensor Parallel (``-tp = N``) - Tensor Parallel
- Quantization (``INT8 W8A8, AWQ``) - Model Quantization (``INT8 W8A8, AWQ``)
- Chunked-prefill
.. note:: - Prefix-caching
More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon. - FP8-E5M2 KV-Caching (TODO)
Table of contents: Table of contents:
......
...@@ -344,7 +344,7 @@ Feature x Hardware ...@@ -344,7 +344,7 @@ Feature x Hardware
- ✅ - ✅
- ✅ - ✅
- ✅ - ✅
- -
- ✅ - ✅
* - :ref:`APC <apc>` * - :ref:`APC <apc>`
- `✗ <https://github.com/vllm-project/vllm/issues/3687>`__ - `✗ <https://github.com/vllm-project/vllm/issues/3687>`__
...@@ -352,7 +352,7 @@ Feature x Hardware ...@@ -352,7 +352,7 @@ Feature x Hardware
- ✅ - ✅
- ✅ - ✅
- ✅ - ✅
- -
- ✅ - ✅
* - :ref:`LoRA <lora>` * - :ref:`LoRA <lora>`
- ✅ - ✅
......
...@@ -12,6 +12,7 @@ from contextlib import nullcontext ...@@ -12,6 +12,7 @@ from contextlib import nullcontext
import pytest import pytest
from tests.kernels.utils import override_backend_env_variable from tests.kernels.utils import override_backend_env_variable
from vllm.platforms import current_platform
from ..models.utils import check_logprobs_close, check_outputs_equal from ..models.utils import check_logprobs_close, check_outputs_equal
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
...@@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache( ...@@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache(
# NOTE: Increasing this in this suite will fail CI because we currently cannot # NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test. # reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1]) @pytest.mark.parametrize("tensor_parallel_size", [1])
@pytest.mark.parametrize("dtype", ["half"])
def test_with_prefix_caching( def test_with_prefix_caching(
vllm_runner, vllm_runner,
max_tokens: int, max_tokens: int,
enforce_eager: bool, enforce_eager: bool,
chunk_size: int, chunk_size: int,
tensor_parallel_size: int, tensor_parallel_size: int,
dtype: str,
) -> None: ) -> None:
""" """
Checks exact match decode with and without prefix caching Checks exact match decode with and without prefix caching
...@@ -233,7 +236,7 @@ def test_with_prefix_caching( ...@@ -233,7 +236,7 @@ def test_with_prefix_caching(
for enable in (True, False): for enable in (True, False):
with vllm_runner( with vllm_runner(
model, model,
dtype="half", dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=True, enable_chunked_prefill=True,
enable_prefix_caching=enable, enable_prefix_caching=enable,
...@@ -260,3 +263,61 @@ def test_with_prefix_caching( ...@@ -260,3 +263,61 @@ def test_with_prefix_caching(
name_0="w/o prefix caching", name_0="w/o prefix caching",
name_1="with prefix caching", name_1="with prefix caching",
) )
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_models_cpu(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
enforce_eager: bool,
attention_backend: str,
monkeypatch,
) -> None:
test_models(
hf_runner,
vllm_runner,
example_prompts,
model,
dtype,
max_tokens,
chunked_prefill_token_size,
enforce_eager,
1,
attention_backend,
monkeypatch,
)
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_with_prefix_caching_cpu(
vllm_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
dtype: str,
) -> None:
test_with_prefix_caching(
vllm_runner,
max_tokens,
enforce_eager,
chunk_size,
1,
dtype,
)
...@@ -7,18 +7,14 @@ import torch ...@@ -7,18 +7,14 @@ import torch
from torch.nn.functional import scaled_dot_product_attention from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.platforms import current_platform from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
if current_platform.is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
else:
from vllm.attention.ops.paged_attn import PagedAttention
class TorchSDPABackend(AttentionBackend): class TorchSDPABackend(AttentionBackend):
...@@ -39,6 +35,10 @@ class TorchSDPABackend(AttentionBackend): ...@@ -39,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
def get_state_cls() -> Type["CommonAttentionState"]: def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState return CommonAttentionState
@staticmethod
def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
return TorchSDPAMetadataBuilder
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
num_blocks: int, num_blocks: int,
...@@ -71,9 +71,15 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -71,9 +71,15 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
""" """
# Currently, input sequences can only contain all prompts # Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts. # or all decoding. True if all sequences are prompts.
is_prompt: bool chunked_prefill: bool
slot_mapping: torch.Tensor seq_lens: Optional[List[int]] = None # For non-chunked prefill
seq_lens: Optional[List[int]]
# For chunked prefill only
max_query_len: Optional[int] = None
max_kv_len: Optional[int] = None
query_start_loc: Optional[torch.Tensor] = None
kv_start_loc: Optional[torch.Tensor] = None
prefill_block_tables: Optional[torch.Tensor] = None
# Begin encoder attn & enc/dec cross-attn fields... # Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation # Encoder sequence lengths representation
...@@ -123,20 +129,14 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -123,20 +129,14 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
@property @property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported if self.num_prefill_tokens == 0:
if self.num_decode_tokens == 0: return None
assert self.num_prefills > 0 return self
return self
return None
@property @property
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported if self.num_decode_tokens == 0:
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
return None return None
return self return self
def get_seq_lens( def get_seq_lens(
...@@ -274,6 +274,105 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -274,6 +274,105 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
raise AttributeError(f"Invalid attention type {str(attn_type)}") raise AttributeError(f"Invalid attention type {str(attn_type)}")
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
self.chunked_prefill = input_builder.chunked_prefill
self.input_data = input_builder.input_data
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
input_data = self.input_data
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
prefill_query_lens = query_lens[0:input_data.num_prefills]
slot_mapping = torch.tensor(input_data.slot_mapping,
dtype=torch.long,
device="cpu")
# For chunked-prefill
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
prefill_block_tables = make_tensor_with_pad(
self.input_data.prefill_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
query_lens_tensor = torch.tensor(prefill_query_lens,
dtype=torch.int32,
device="cpu")
kv_lens_tensor = torch.tensor(prefill_seq_lens,
dtype=torch.int32,
device="cpu")
query_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
dtype=torch.int32,
device="cpu")
torch.cumsum(query_lens_tensor,
dim=0,
dtype=torch.int32,
out=query_start_loc[1:])
torch.cumsum(kv_lens_tensor,
dim=0,
dtype=torch.int32,
out=kv_start_loc[1:])
max_query_len = max(prefill_query_lens)
max_kv_len = max(prefill_seq_lens)
else:
prefill_block_tables = None
query_start_loc = None
kv_start_loc = None
max_query_len = None
max_kv_len = None
# For paged attention
if input_data.num_decode_tokens != 0:
seq_lens_tensor = torch.tensor(
input_data.seq_lens[input_data.num_prefills:],
dtype=torch.int32,
device="cpu",
)
block_tables = make_tensor_with_pad(
self.input_data.decode_block_tables,
pad=0,
dtype=torch.int32,
device="cpu",
)
else:
block_tables = torch.tensor([])
seq_lens_tensor = torch.tensor([])
# For multi-modal models
placeholder_index_maps = None
if len(input_data.multi_modal_inputs_list) != 0:
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
input_data.multi_modal_placeholder_maps.items()
}
attn_metadata = TorchSDPAMetadata(
chunked_prefill=self.chunked_prefill,
seq_lens=prefill_seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_kv_len=max_kv_len,
query_start_loc=query_start_loc,
kv_start_loc=kv_start_loc,
max_decode_seq_len=input_data.max_decode_seq_len,
num_prefills=input_data.num_prefills,
num_prefill_tokens=input_data.num_prefill_tokens,
num_decode_tokens=input_data.num_decode_tokens,
block_tables=block_tables,
prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps,
)
return attn_metadata
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def __init__( def __init__(
...@@ -409,19 +508,35 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -409,19 +508,35 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
if (kv_cache.numel() == 0 if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore
or prefill_meta.block_tables.numel() == 0): self._run_sdpa_forward(output,
output = self._run_sdpa_forward(query, query,
key, key,
value, value,
prefill_meta, prefill_meta,
attn_type=attn_type) attn_type=attn_type)
else: else:
# prefix-enabled attention # prefix-enabled attention
raise RuntimeError( assert not self.need_mask
"Torch SDPA backend doesn't support prefix decoding.") import intel_extension_for_pytorch.llm.modules as ipex_modules
output = torch.empty_like(query)
ipex_modules.PagedAttention.flash_attn_varlen_func(
output[:prefill_meta.num_prefill_tokens, :, :],
query[:prefill_meta.num_prefill_tokens, :, :],
key_cache,
value_cache,
prefill_meta.query_start_loc,
prefill_meta.kv_start_loc,
prefill_meta.max_query_len,
prefill_meta.max_kv_len,
self.scale,
True,
prefill_meta.prefill_block_tables,
self.alibi_slopes,
)
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, ( assert attn_type != AttentionType.ENCODER_ONLY, (
...@@ -433,8 +548,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -433,8 +548,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
block_tables_arg, block_tables_arg,
) = decode_meta.get_seq_len_block_table_args(attn_type) ) = decode_meta.get_seq_len_block_table_args(attn_type)
output = PagedAttention.forward_decode( PagedAttention.forward_decode(
query, output[attn_metadata.num_prefill_tokens:, :, :],
query[attn_metadata.num_prefill_tokens:, :, :],
key_cache, key_cache,
value_cache, value_cache,
block_tables_arg, block_tables_arg,
...@@ -453,12 +569,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -453,12 +569,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def _run_sdpa_forward( def _run_sdpa_forward(
self, self,
output: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_metadata: TorchSDPAMetadata, attn_metadata: TorchSDPAMetadata,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
): ) -> None:
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
...@@ -479,7 +596,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -479,7 +596,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_masks = [None] * len(seq_lens) attn_masks = [None] * len(seq_lens)
attn_metadata.set_attn_bias(attn_masks, attn_type) attn_metadata.set_attn_bias(attn_masks, attn_type)
output = torch.empty_like(query)
query = query.movedim(0, query.dim() - 2) query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2) key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2) value = value.movedim(0, value.dim() - 2)
...@@ -502,7 +618,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -502,7 +618,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
output[start_q:end_q, :, :] = sub_out output[start_q:end_q, :, :] = sub_out
start_q, start_kv = end_q, end_kv start_q, start_kv = end_q, end_kv
return output
def _make_alibi_bias( def _make_alibi_bias(
......
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import intel_extension_for_pytorch.llm.modules as ipex_modules try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
_use_ipex = True
except ImportError:
_use_ipex = False
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
class PagedAttention: class _PagedAttention:
@staticmethod @staticmethod
def get_supported_head_sizes() -> List[int]: def get_supported_head_sizes() -> List[int]:
...@@ -22,6 +27,105 @@ class PagedAttention: ...@@ -22,6 +27,105 @@ class PagedAttention:
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size) return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
*args,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
*args,
) -> None:
tp_rank: int = 0
blocksparse_local_blocks: int = 0
blocksparse_vert_stride: int = 0
blocksparse_block_size: int = 64
blocksparse_head_sliding_step: int = 0
block_size = value_cache.shape[3]
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
class _IPEXPagedAttention(_PagedAttention):
@staticmethod @staticmethod
def split_kv_cache( def split_kv_cache(
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
...@@ -55,6 +159,7 @@ class PagedAttention: ...@@ -55,6 +159,7 @@ class PagedAttention:
@staticmethod @staticmethod
def forward_decode( def forward_decode(
output: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
...@@ -68,8 +173,7 @@ class PagedAttention: ...@@ -68,8 +173,7 @@ class PagedAttention:
k_scale: float, k_scale: float,
v_scale: float, v_scale: float,
*args, *args,
) -> torch.Tensor: ) -> None:
output = torch.empty_like(query)
block_size = value_cache.shape[2] block_size = value_cache.shape[2]
head_mapping = torch.arange( head_mapping = torch.arange(
0, 0,
...@@ -83,41 +187,5 @@ class PagedAttention: ...@@ -83,41 +187,5 @@ class PagedAttention:
scale, block_tables, context_lens, block_size, max_context_len, scale, block_tables, context_lens, block_size, max_context_len,
alibi_slopes) alibi_slopes)
return output
@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache_dtype: str,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
alibi_slopes: Optional[torch.Tensor],
*args,
) -> torch.Tensor:
raise NotImplementedError
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
*args,
) -> None:
raise NotImplementedError
@staticmethod PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
...@@ -53,11 +53,6 @@ class CpuPlatform(Platform): ...@@ -53,11 +53,6 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config.enable_prefix_caching:
logger.warning(
"Prefix caching is not supported on CPU, disable it.")
cache_config.enable_prefix_caching = False
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
if kv_cache_space >= 0: if kv_cache_space >= 0:
...@@ -74,10 +69,12 @@ class CpuPlatform(Platform): ...@@ -74,10 +69,12 @@ class CpuPlatform(Platform):
f" {kv_cache_space}, expect a positive integer value.") f" {kv_cache_space}, expect a positive integer value.")
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
if scheduler_config.chunked_prefill_enabled: if ((scheduler_config.chunked_prefill_enabled
logger.warning( or cache_config.enable_prefix_caching)
"Chunked prefill is not supported on CPU, disable it.") and model_config.dtype == torch.half):
scheduler_config.chunked_prefill_enabled = False logger.warning("Chunked-prefill on the CPU backend only does not"
" support fp16 for now, cast to bf16.")
model_config.dtype = torch.bfloat16
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
if (parallel_config.distributed_executor_backend is not None if (parallel_config.distributed_executor_backend is not None
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment