Unverified Commit bf3ffb61 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Bugfix] Fix ChunkedLocalAttention CUDA Graph setting (#28739)


Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
parent e5c78956
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
from typing import ClassVar
import torch import torch
...@@ -12,11 +11,16 @@ from vllm.config.vllm import VllmConfig ...@@ -12,11 +11,16 @@ from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
make_local_attention_virtual_batches, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec from vllm.v1.kv_cache_interface import (
AttentionSpec,
ChunkedLocalAttentionSpec,
KVCacheSpec,
)
from ..layer import Attention from ..layer import Attention
...@@ -30,9 +34,18 @@ def create_chunked_local_attention_backend( ...@@ -30,9 +34,18 @@ def create_chunked_local_attention_backend(
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
assert issubclass(underlying_builder, AttentionMetadataBuilder)
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER @classmethod
def get_cudagraph_support(
cls: type["AttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.NEVER
def build( def build(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment