Unverified Commit bf3ffb61 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Bugfix] Fix ChunkedLocalAttention CUDA Graph setting (#28739)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent e5c78956
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import ClassVar
import torch
......@@ -12,11 +11,16 @@ from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
make_local_attention_virtual_batches,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec
from vllm.v1.kv_cache_interface import (
AttentionSpec,
ChunkedLocalAttentionSpec,
KVCacheSpec,
)
from ..layer import Attention
......@@ -30,9 +34,18 @@ def create_chunked_local_attention_backend(
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
underlying_builder = underlying_attn_backend.get_builder_cls()
assert issubclass(underlying_builder, AttentionMetadataBuilder)
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
@classmethod
def get_cudagraph_support(
cls: type["AttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.NEVER
def build(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment