Unverified commit 5510cf0e, authored by Woosuk Kwon, committed by GitHub
Browse files

[Misc] Add `get_name` method to attention backends (#4685)

parent 0f9a6e3d
......@@ -9,6 +9,11 @@ import torch
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
@staticmethod
@abstractmethod
def get_name() -> str:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
......
......@@ -19,6 +19,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "flash-attn"
@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl
......
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type
try:
import flashinfer
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
except ImportError:
flashinfer = None
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
import flashinfer
import torch
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
......@@ -20,6 +14,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
class FlashInferBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "flashinfer"
@staticmethod
def get_impl_cls() -> Type["FlashInferImpl"]:
return FlashInferImpl
......
......@@ -17,6 +17,10 @@ logger = init_logger(__name__)
class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "rocm-flash-attn"
@staticmethod
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
return ROCmFlashAttentionImpl
......
......@@ -15,6 +15,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
class TorchSDPABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "torch-sdpa"
@staticmethod
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl
......
......@@ -20,6 +20,10 @@ logger = init_logger(__name__)
class XFormersBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "xformers"
@staticmethod
def get_impl_cls() -> Type["XFormersImpl"]:
return XFormersImpl
......
......@@ -9,7 +9,6 @@ import torch.nn as nn
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.attention.backends.flashinfer import FlashInferBackend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
......@@ -395,7 +394,7 @@ class ModelRunner:
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
if self.attn_backend is FlashInferBackend:
if self.attn_backend.get_name() == "flashinfer":
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
use_cuda_graph=False,
......@@ -556,7 +555,7 @@ class ModelRunner:
device=self.device,
)
if self.attn_backend is FlashInferBackend:
if self.attn_backend.get_name() == "flashinfer":
if not hasattr(self, "flashinfer_workspace_buffer"):
# Allocate 16MB workspace buffer
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment