# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Contains replacement functions to fallback Triton usages in CPU backend """ from collections.abc import Callable import torch class _FuncWrapper: def __init__(self, func: Callable) -> None: self.func = func def __getitem__(self, *args, **kwargs) -> Callable: return self.func # For _compute_slot_mapping_kernel in vllm/v1/worker/block_table.py def _compute_slot_mapping_kernel_impl( num_tokens: int, max_num_tokens: int, query_start_loc: torch.Tensor, # [num_reqs + 1], int32 positions: torch.Tensor, # [num_tokens], int64 block_table: torch.Tensor, # [max_num_reqs, max_num_blocks_per_req], int32 block_table_stride: int, # max_num_blocks_per_req block_size: int, slot_mapping: torch.Tensor, # [max_num_tokens], int64 TOTAL_CP_WORLD_SIZE: int, TOTAL_CP_RANK: int, CP_KV_CACHE_INTERLEAVE_SIZE: int, PAD_ID: int, BLOCK_SIZE: int, ) -> None: assert TOTAL_CP_WORLD_SIZE == 1, "Context Parallelism is not supported on CPU." torch.ops._C.compute_slot_mapping_kernel_impl( query_start_loc, positions, block_table, slot_mapping, block_size, ) compute_slot_mapping_kernel = _FuncWrapper(_compute_slot_mapping_kernel_impl)