Unverified Commit 5510cf0e authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Add `get_name` method to attention backends (#4685)

parent 0f9a6e3d
...@@ -9,6 +9,11 @@ import torch ...@@ -9,6 +9,11 @@ import torch
class AttentionBackend(ABC): class AttentionBackend(ABC):
"""Abstract class for attention backends.""" """Abstract class for attention backends."""
@staticmethod
@abstractmethod
def get_name() -> str:
raise NotImplementedError
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]: def get_impl_cls() -> Type["AttentionImpl"]:
......
...@@ -19,6 +19,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention, ...@@ -19,6 +19,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "flash-attn"
@staticmethod @staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]: def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl return FlashAttentionImpl
......
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type from typing import Any, Dict, List, Optional, Set, Tuple, Type
try: import flashinfer
import flashinfer
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
except ImportError:
flashinfer = None
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
import torch import torch
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...@@ -20,6 +14,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -20,6 +14,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "flashinfer"
@staticmethod @staticmethod
def get_impl_cls() -> Type["FlashInferImpl"]: def get_impl_cls() -> Type["FlashInferImpl"]:
return FlashInferImpl return FlashInferImpl
......
...@@ -17,6 +17,10 @@ logger = init_logger(__name__) ...@@ -17,6 +17,10 @@ logger = init_logger(__name__)
class ROCmFlashAttentionBackend(AttentionBackend): class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "rocm-flash-attn"
@staticmethod @staticmethod
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
return ROCmFlashAttentionImpl return ROCmFlashAttentionImpl
......
...@@ -15,6 +15,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention, ...@@ -15,6 +15,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
class TorchSDPABackend(AttentionBackend): class TorchSDPABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "torch-sdpa"
@staticmethod @staticmethod
def get_impl_cls() -> Type["TorchSDPABackendImpl"]: def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
return TorchSDPABackendImpl return TorchSDPABackendImpl
......
...@@ -20,6 +20,10 @@ logger = init_logger(__name__) ...@@ -20,6 +20,10 @@ logger = init_logger(__name__)
class XFormersBackend(AttentionBackend): class XFormersBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "xformers"
@staticmethod @staticmethod
def get_impl_cls() -> Type["XFormersImpl"]: def get_impl_cls() -> Type["XFormersImpl"]:
return XFormersImpl return XFormersImpl
......
...@@ -9,7 +9,6 @@ import torch.nn as nn ...@@ -9,7 +9,6 @@ import torch.nn as nn
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend) get_attn_backend)
from vllm.attention.backends.flashinfer import FlashInferBackend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
...@@ -395,7 +394,7 @@ class ModelRunner: ...@@ -395,7 +394,7 @@ class ModelRunner:
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
out=seq_start_loc[1:]) out=seq_start_loc[1:])
if self.attn_backend is FlashInferBackend: if self.attn_backend.get_name() == "flashinfer":
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
use_cuda_graph=False, use_cuda_graph=False,
...@@ -556,7 +555,7 @@ class ModelRunner: ...@@ -556,7 +555,7 @@ class ModelRunner:
device=self.device, device=self.device,
) )
if self.attn_backend is FlashInferBackend: if self.attn_backend.get_name() == "flashinfer":
if not hasattr(self, "flashinfer_workspace_buffer"): if not hasattr(self, "flashinfer_workspace_buffer"):
# Allocate 16MB workspace buffer # Allocate 16MB workspace buffer
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment