Unverified Commit 6f0f570c authored by Chen Zhang, committed by GitHub
Browse files

[deepseek] kernel block size for UniformTypeKVCacheSpecs (#26559)


Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent b545a0b2
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional from typing import ClassVar, Optional, Union
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
MultipleOf,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
...@@ -47,6 +51,10 @@ class DeepseekV32IndexerBackend(AttentionBackend): ...@@ -47,6 +51,10 @@ class DeepseekV32IndexerBackend(AttentionBackend):
def get_kv_cache_stride_order() -> tuple[int, ...]: def get_kv_cache_stride_order() -> tuple[int, ...]:
return (0, 1, 2) return (0, 1, 2)
@classmethod
def get_supported_kernel_block_size(cls) -> list[Union[int, MultipleOf]]:
return [64]
@dataclass @dataclass
class DeepseekV32IndexerPrefillChunkMetadata: class DeepseekV32IndexerPrefillChunkMetadata:
......
...@@ -4242,9 +4242,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4242,9 +4242,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for kv_cache_group_id, kv_cache_group in enumerate( for kv_cache_group_id, kv_cache_group in enumerate(
kv_cache_config.kv_cache_groups kv_cache_config.kv_cache_groups
): ):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): kv_cache_spec = kv_cache_group.kv_cache_spec
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
# All layers in the UniformTypeKVCacheSpecs have the same type,
# so pick an arbitrary one to dispatch on.
kv_cache_spec = next(iter(kv_cache_spec.kv_cache_specs.values()))
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
continue continue
elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec): elif isinstance(kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual # This is an attention backend that supports virtual
# block splitting. Get the supported block sizes from # block splitting. Get the supported block sizes from
# all backends in the group. # all backends in the group.
...@@ -4254,10 +4259,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4254,10 +4259,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_manager_block_size, attn_groups kv_manager_block_size, attn_groups
) )
kernel_block_sizes.append(selected_kernel_size) kernel_block_sizes.append(selected_kernel_size)
elif isinstance(kv_cache_group.kv_cache_spec, MambaSpec): elif isinstance(kv_cache_spec, MambaSpec):
# This is likely Mamba or other non-attention cache, # This is likely Mamba or other non-attention cache,
# no splitting. # no splitting.
kernel_block_sizes.append(kv_cache_group.kv_cache_spec.block_size) kernel_block_sizes.append(kv_cache_spec.block_size)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"unknown kv cache spec {kv_cache_group.kv_cache_spec}" f"unknown kv cache spec {kv_cache_group.kv_cache_spec}"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment