Unverified Commit 2aaa4238 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Move Backend enum into registry (#25893)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent ad2d7880
...@@ -11,8 +11,8 @@ import pytest ...@@ -11,8 +11,8 @@ import pytest
import torch import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
PassConfig) PassConfig)
......
...@@ -8,11 +8,11 @@ import torch._dynamo ...@@ -8,11 +8,11 @@ import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend, from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
create_common_attn_metadata)
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
......
...@@ -10,8 +10,9 @@ from unittest.mock import patch ...@@ -10,8 +10,9 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.attention.selector import _Backend, _cached_get_attn_backend from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
......
...@@ -15,10 +15,10 @@ from torch._prims_common import TensorLikeType ...@@ -15,10 +15,10 @@ from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.attention.backends.registry import _Backend
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import ( from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input) moe_kernel_quantize_input)
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
......
...@@ -8,11 +8,11 @@ import pytest ...@@ -8,11 +8,11 @@ import pytest
import torch import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from tests.v1.attention.utils import (BatchSpec, _Backend, from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_common_attn_metadata,
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
create_vllm_config, create_vllm_config,
get_attention_backend) get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
......
...@@ -6,12 +6,12 @@ from typing import Optional, Union ...@@ -6,12 +6,12 @@ from typing import Optional, Union
import pytest import pytest
import torch import torch
from tests.v1.attention.utils import (BatchSpec, _Backend, from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_common_attn_metadata,
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
create_vllm_config, create_vllm_config,
get_attention_backend) get_attention_backend)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
......
...@@ -8,10 +8,11 @@ from typing import Optional, Union ...@@ -8,10 +8,11 @@ from typing import Optional, Union
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
LoadConfig, ModelConfig, ModelDType, ParallelConfig, LoadConfig, ModelConfig, ModelDType, ParallelConfig,
SchedulerConfig, VllmConfig) SchedulerConfig, VllmConfig)
from vllm.platforms import _Backend, current_platform from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
......
...@@ -8,10 +8,10 @@ import pytest ...@@ -8,10 +8,10 @@ import pytest
import torch import torch
from tests.utils import get_attn_backend_list_based_on_platform from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, _Backend, from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_common_attn_metadata,
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
get_attention_backend) get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig) VllmConfig)
......
...@@ -6,10 +6,10 @@ from unittest import mock ...@@ -6,10 +6,10 @@ from unittest import mock
import pytest import pytest
import torch import torch
from tests.v1.attention.utils import (BatchSpec, _Backend, from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_common_attn_metadata,
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
get_attention_backend) get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig) VllmConfig)
......
...@@ -6,9 +6,10 @@ from typing import Optional ...@@ -6,9 +6,10 @@ from typing import Optional
import torch import torch
from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec, from tests.v1.attention.utils import (create_standard_kv_cache_spec,
create_vllm_config, create_vllm_config,
get_attention_backend) get_attention_backend)
from vllm.attention.backends.registry import _Backend
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend registry"""
import enum
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto()
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto()
FLASH_ATTN_MLA = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_ATTN = enum.auto()
...@@ -10,6 +10,7 @@ import torch.nn.functional as F ...@@ -10,6 +10,7 @@ import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
...@@ -26,7 +27,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod ...@@ -26,7 +27,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape) GroupShape)
from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op from vllm.utils import GiB_bytes, direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -11,8 +11,9 @@ import torch ...@@ -11,8 +11,9 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -20,6 +20,7 @@ import torch ...@@ -20,6 +20,7 @@ import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
...@@ -32,7 +33,7 @@ from vllm.distributed.parallel_state import ( ...@@ -32,7 +33,7 @@ from vllm.distributed.parallel_state import (
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform from vllm.platforms import current_platform
from vllm.utils import make_zmq_path, make_zmq_socket from vllm.utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
......
...@@ -619,7 +619,8 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -619,7 +619,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
# All possible options loaded dynamically from _Backend enum # All possible options loaded dynamically from _Backend enum
"VLLM_ATTENTION_BACKEND": "VLLM_ATTENTION_BACKEND":
env_with_choices("VLLM_ATTENTION_BACKEND", None, env_with_choices("VLLM_ATTENTION_BACKEND", None,
lambda: list(__import__('vllm.platforms.interface', \ lambda: list(__import__(
'vllm.attention.backends.registry',
fromlist=['_Backend'])._Backend.__members__.keys())), fromlist=['_Backend'])._Backend.__members__.keys())),
# If set, vllm will use flashinfer sampler # If set, vllm will use flashinfer sampler
......
...@@ -9,6 +9,7 @@ import torch.nn.functional as F ...@@ -9,6 +9,7 @@ import torch.nn.functional as F
from torch.nn import LayerNorm from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor from transformers.models.qwen2_vl import Qwen2VLProcessor
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -38,7 +39,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, ...@@ -38,7 +39,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict from vllm.multimodal.inputs import MultiModalDataDict
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig, from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
DotsVisionConfig) DotsVisionConfig)
......
...@@ -34,6 +34,7 @@ import torch.nn.functional as F ...@@ -34,6 +34,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from transformers import BatchFeature from transformers import BatchFeature
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
...@@ -54,7 +55,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -54,7 +55,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
...@@ -46,6 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import ( ...@@ -46,6 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
Glm4vVideoProcessor) Glm4vVideoProcessor)
from transformers.video_utils import VideoMetadata from transformers.video_utils import VideoMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size, from vllm.distributed import (get_tensor_model_parallel_world_size,
...@@ -69,7 +70,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -69,7 +70,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
...@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput, ...@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling) BaseModelOutputWithPooling)
from transformers.utils import torch_int from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
...@@ -39,7 +40,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -39,7 +40,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
...@@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor ...@@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
...@@ -62,7 +63,6 @@ from vllm.multimodal.evs import (compute_mrope_for_media, ...@@ -62,7 +63,6 @@ from vllm.multimodal.evs import (compute_mrope_for_media,
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment