Unverified commit 2aaa4238, authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention] Move Backend enum into registry (#25893)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent ad2d7880
......@@ -11,8 +11,8 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
PassConfig)
......
......@@ -8,11 +8,11 @@ import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
......
......@@ -10,8 +10,9 @@ from unittest.mock import patch
import pytest
import torch
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _Backend, _cached_get_attn_backend
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
......
......@@ -15,10 +15,10 @@ from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.attention.backends.registry import _Backend
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
......
......@@ -8,11 +8,11 @@ import pytest
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
......
......@@ -6,12 +6,12 @@ from typing import Optional, Union
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
......
......@@ -8,10 +8,11 @@ from typing import Optional, Union
import pytest
import torch
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
LoadConfig, ModelConfig, ModelDType, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
......
......@@ -8,10 +8,10 @@ import pytest
import torch
from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
......
......@@ -6,10 +6,10 @@ from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
......
......@@ -6,9 +6,10 @@ from typing import Optional
import torch
from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
from tests.v1.attention.utils import (create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend registry"""
import enum
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto()
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto()
FLASH_ATTN_MLA = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_ATTN = enum.auto()
......@@ -10,6 +10,7 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
......@@ -26,7 +27,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op
logger = init_logger(__name__)
......
......@@ -11,8 +11,9 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
logger = init_logger(__name__)
......
......@@ -20,6 +20,7 @@ import torch
import zmq
from vllm import envs
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
......@@ -32,7 +33,7 @@ from vllm.distributed.parallel_state import (
from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
......
......@@ -619,7 +619,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# All possible options loaded dynamically from _Backend enum
"VLLM_ATTENTION_BACKEND":
env_with_choices("VLLM_ATTENTION_BACKEND", None,
lambda: list(__import__('vllm.platforms.interface', \
lambda: list(__import__(
'vllm.attention.backends.registry',
fromlist=['_Backend'])._Backend.__members__.keys())),
# If set, vllm will use flashinfer sampler
......
......@@ -9,6 +9,7 @@ import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import utils as dist_utils
......@@ -38,7 +39,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
DotsVisionConfig)
......
......@@ -34,6 +34,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from transformers import BatchFeature
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
......@@ -54,7 +55,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
......@@ -46,6 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
Glm4vVideoProcessor)
from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
......@@ -69,7 +70,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
......@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
......@@ -39,7 +40,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
......@@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
......@@ -62,7 +63,6 @@ from vllm.multimodal.evs import (compute_mrope_for_media,
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment