Unverified Commit 6f0f570c authored by Chen Zhang's avatar Chen Zhang Committed by GitHub
Browse files

[deepseek] kernel block size for UniformTypeKVCacheSpecs (#26559)


Signed-off-by: default avatarChen Zhang <zhangch99@outlook.com>
parent b545a0b2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import ClassVar, Optional, Union
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
MultipleOf,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
......@@ -47,6 +51,10 @@ class DeepseekV32IndexerBackend(AttentionBackend):
def get_kv_cache_stride_order() -> tuple[int, ...]:
return (0, 1, 2)
@classmethod
def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]:
return [64]
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
......
......@@ -4242,9 +4242,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for kv_cache_group_id, kv_cache_group in enumerate(
kv_cache_config.kv_cache_groups
):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
kv_cache_spec = kv_cache_group.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
# All layers in the UniformTypeKVCacheSpecs have the same type,
# Pick an arbitrary one to dispatch.
kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
continue
elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
elif isinstance(kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual
# block splitting. Get the supported block sizes from
# all backends in the group.
......@@ -4254,10 +4259,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_manager_block_size, attn_groups
)
kernel_block_sizes.append(selected_kernel_size)
elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
elif isinstance(kv_cache_spec, MambaSpec):
# This is likely Mamba or other non-attention cache,
# no splitting.
kernel_block_sizes.append(kv_cache_group.kv_cache_spec.block_size)
kernel_block_sizes.append(kv_cache_spec.block_size)
else:
raise NotImplementedError(
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment