Unverified Commit 352b90c4 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[Bugfix] Add replacement of _compute_slot_mapping_kernel on CPU (#37987)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
parent 1c0aabde
...@@ -3,7 +3,6 @@ depends_on: [] ...@@ -3,7 +3,6 @@ depends_on: []
steps: steps:
- label: CPU-Kernel Tests - label: CPU-Kernel Tests
depends_on: [] depends_on: []
soft_fail: true
device: intel_cpu device: intel_cpu
no_plugin: true no_plugin: true
source_file_dependencies: source_file_dependencies:
...@@ -23,7 +22,6 @@ steps: ...@@ -23,7 +22,6 @@ steps:
- label: CPU-Compatibility Tests - label: CPU-Compatibility Tests
depends_on: [] depends_on: []
soft_fail: true
device: intel_cpu device: intel_cpu
no_plugin: true no_plugin: true
source_file_dependencies: source_file_dependencies:
...@@ -37,7 +35,6 @@ steps: ...@@ -37,7 +35,6 @@ steps:
- label: CPU-Language Generation and Pooling Model Tests - label: CPU-Language Generation and Pooling Model Tests
depends_on: [] depends_on: []
soft_fail: true
device: intel_cpu device: intel_cpu
no_plugin: true no_plugin: true
source_file_dependencies: source_file_dependencies:
...@@ -53,7 +50,6 @@ steps: ...@@ -53,7 +50,6 @@ steps:
- label: CPU-Quantization Model Tests - label: CPU-Quantization Model Tests
depends_on: [] depends_on: []
soft_fail: true
device: intel_cpu device: intel_cpu
no_plugin: true no_plugin: true
source_file_dependencies: source_file_dependencies:
...@@ -73,7 +69,6 @@ steps: ...@@ -73,7 +69,6 @@ steps:
- label: CPU-Distributed Tests - label: CPU-Distributed Tests
depends_on: [] depends_on: []
soft_fail: true
device: intel_cpu device: intel_cpu
no_plugin: true no_plugin: true
source_file_dependencies: source_file_dependencies:
...@@ -92,7 +87,6 @@ steps: ...@@ -92,7 +87,6 @@ steps:
- label: CPU-Multi-Modal Model Tests %N - label: CPU-Multi-Modal Model Tests %N
depends_on: [] depends_on: []
soft_fail: true
device: intel_cpu device: intel_cpu
no_plugin: true no_plugin: true
source_file_dependencies: source_file_dependencies:
...@@ -107,7 +101,6 @@ steps: ...@@ -107,7 +101,6 @@ steps:
- label: "Arm CPU Test" - label: "Arm CPU Test"
depends_on: [] depends_on: []
soft_fail: true
device: arm_cpu device: arm_cpu
no_plugin: true no_plugin: true
commands: commands:
......
...@@ -126,6 +126,12 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input, ...@@ -126,6 +126,12 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const torch::Tensor& topk_id, const bool skip_weighted, const torch::Tensor& topk_id, const bool skip_weighted,
const std::string& act, const std::string& isa); const std::string& act, const std::string& isa);
void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
const torch::Tensor positions,
const torch::Tensor block_table,
torch::Tensor slot_mapping,
const int64_t block_size);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
...@@ -334,6 +340,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -334,6 +340,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! out, Tensor query, Tensor kv_cache," " Tensor! out, Tensor query, Tensor kv_cache,"
" float scale, Tensor block_tables, Tensor seq_lens) -> ()"); " float scale, Tensor block_tables, Tensor seq_lens) -> ()");
ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache); ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
ops.def(
"compute_slot_mapping_kernel_impl(Tensor query_start_loc, Tensor "
"positions, Tensor block_table, Tensor(a3!) slot_mapping, SymInt "
"block_size) -> ()",
&compute_slot_mapping_kernel_impl);
} }
REGISTER_EXTENSION(TORCH_EXTENSION_NAME) REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
...@@ -189,3 +189,38 @@ ScratchPadManager* ScratchPadManager::get_scratchpad_manager() { ...@@ -189,3 +189,38 @@ ScratchPadManager* ScratchPadManager::get_scratchpad_manager() {
return &manager; return &manager;
} }
} // namespace cpu_utils } // namespace cpu_utils
void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
const torch::Tensor positions,
const torch::Tensor block_table,
torch::Tensor slot_mapping,
const int64_t block_size) {
const int32_t req_num = query_start_loc.size(0) - 1;
const int64_t block_table_stride = block_table.stride(0);
const int32_t* __restrict__ query_start_loc_ptr =
query_start_loc.data_ptr<int32_t>();
const int64_t* __restrict__ positions_ptr = positions.data_ptr<int64_t>();
const int32_t* __restrict__ blocktable_ptr = block_table.data_ptr<int32_t>();
int64_t* __restrict__ slot_mapping_ptr = slot_mapping.data_ptr<int64_t>();
#pragma omp parallel for
for (int32_t req_idx = 0; req_idx < req_num; ++req_idx) {
int32_t token_start_idx = query_start_loc_ptr[req_idx];
int32_t token_end_idx = query_start_loc_ptr[req_idx + 1];
int32_t token_num = token_end_idx - token_start_idx;
const int64_t* __restrict__ curr_position_ptr =
positions_ptr + token_start_idx;
int64_t* __restrict__ curr_slot_mapping_ptr =
slot_mapping_ptr + token_start_idx;
const int32_t* __restrict__ curr_block_table_ptr =
blocktable_ptr + req_idx * block_table_stride;
for (int32_t token_idx = 0; token_idx < token_num; ++token_idx) {
int64_t token_position = curr_position_ptr[token_idx];
int64_t block_id = curr_block_table_ptr[token_position / block_size];
curr_slot_mapping_ptr[token_idx] =
block_id * block_size + token_position % block_size;
}
}
}
...@@ -161,7 +161,7 @@ RUN ln -s /usr/bin/clangd-14 /usr/bin/clangd ...@@ -161,7 +161,7 @@ RUN ln -s /usr/bin/clangd-14 /usr/bin/clangd
# install development dependencies (for testing) # install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -e tests/vllm_test_utils uv pip install --no-build-isolation -e tests/vllm_test_utils
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/ccache \
......
...@@ -309,7 +309,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -309,7 +309,7 @@ class AWQMarlinConfig(QuantizationConfig):
group_size = quant_config.get("group_size") group_size = quant_config.get("group_size")
zero_point = quant_config.get("zero_point") zero_point = quant_config.get("zero_point")
if not (current_platform.is_cuda_alike() or current_platform.is_cpu()): if not current_platform.is_cuda_alike():
return False return False
if quant_method != "awq": if quant_method != "awq":
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains replacement functions to fallback Triton usages in CPU backend
"""
from collections.abc import Callable
import torch
class _FuncWrapper:
def __init__(self, func: Callable) -> None:
self.func = func
def __getitem__(self, *args, **kwargs) -> Callable:
return self.func
# For _compute_slot_mapping_kernel in vllm/v1/worker/block_table.py
def _compute_slot_mapping_kernel_impl(
num_tokens: int,
max_num_tokens: int,
query_start_loc: torch.Tensor, # [num_reqs + 1], int32
positions: torch.Tensor, # [num_tokens], int64
block_table: torch.Tensor, # [max_num_reqs, max_num_blocks_per_req], int32
block_table_stride: int, # max_num_blocks_per_req
block_size: int,
slot_mapping: torch.Tensor, # [max_num_tokens], int64
TOTAL_CP_WORLD_SIZE: int,
TOTAL_CP_RANK: int,
CP_KV_CACHE_INTERLEAVE_SIZE: int,
PAD_ID: int,
BLOCK_SIZE: int,
) -> None:
assert TOTAL_CP_WORLD_SIZE == 1, "Context Parallelism is not supported on CPU."
torch.ops._C.compute_slot_mapping_kernel_impl(
query_start_loc,
positions,
block_table,
slot_mapping,
block_size,
)
compute_slot_mapping_kernel = _FuncWrapper(_compute_slot_mapping_kernel_impl)
...@@ -6,6 +6,7 @@ from typing import Any ...@@ -6,6 +6,7 @@ from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
import vllm.utils.cpu_triton_utils as cpu_tl
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -28,6 +29,7 @@ class CPUModelRunner(GPUModelRunner): ...@@ -28,6 +29,7 @@ class CPUModelRunner(GPUModelRunner):
self.cascade_attn_enabled = False self.cascade_attn_enabled = False
self._postprocess_tensors() self._postprocess_tensors()
self._postprocess_triton()
def _postprocess_tensors(self) -> None: def _postprocess_tensors(self) -> None:
# Note: replace device tensors with cpu tensors # Note: replace device tensors with cpu tensors
...@@ -52,6 +54,13 @@ class CPUModelRunner(GPUModelRunner): ...@@ -52,6 +54,13 @@ class CPUModelRunner(GPUModelRunner):
if isinstance(v, CpuGpuBuffer): if isinstance(v, CpuGpuBuffer):
v.gpu = v.cpu v.gpu = v.cpu
def _postprocess_triton(self) -> None:
import vllm.v1.worker.block_table
vllm.v1.worker.block_table._compute_slot_mapping_kernel = (
cpu_tl.compute_slot_mapping_kernel
)
@instrument(span_name="Loading (CPU)") @instrument(span_name="Loading (CPU)")
def load_model(self, load_dummy_weights: bool = False) -> None: def load_model(self, load_dummy_weights: bool = False) -> None:
if load_dummy_weights: if load_dummy_weights:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment