Unverified Commit d983769c authored by who who who's avatar who who who Committed by GitHub
Browse files

fix cuda graph (#22721)


Signed-off-by: default avatarfsx950223 <fsx950223@outlook.com>
parent 8fd92092
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional
import torch
......@@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
......@@ -231,7 +232,7 @@ class AiterFlashAttentionMetadata:
class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True
cudagraph_support = AttentionCGSupport.ALWAYS
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment