Commit 2612ba92 (unverified), authored by Matthew Bonanni, committed by GitHub
Browse files

[1/N][Attention] Restructure attention: move files (#31916)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 1f8b7c53
...@@ -6,9 +6,9 @@ import random ...@@ -6,9 +6,9 @@ import random
import pytest import pytest
import torch import torch
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
skip_unsupported = pytest.mark.skipif( skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(80)), not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
......
...@@ -14,12 +14,12 @@ from unittest.mock import patch ...@@ -14,12 +14,12 @@ from unittest.mock import patch
import pytest import pytest
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import ( from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorRole, KVConnectorRole,
) )
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config from .utils import create_scheduler, create_vllm_config
......
...@@ -13,7 +13,6 @@ from tests.v1.attention.utils import ( ...@@ -13,7 +13,6 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
AttentionConfig, AttentionConfig,
CacheConfig, CacheConfig,
...@@ -27,6 +26,7 @@ from vllm.config import ( ...@@ -27,6 +26,7 @@ from vllm.config import (
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
......
...@@ -12,7 +12,6 @@ from tests.v1.attention.utils import ( ...@@ -12,7 +12,6 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
DeviceConfig, DeviceConfig,
...@@ -25,6 +24,7 @@ from vllm.config import ( ...@@ -25,6 +24,7 @@ from vllm.config import (
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base" mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
......
...@@ -11,9 +11,9 @@ from tests.v1.attention.utils import ( ...@@ -11,9 +11,9 @@ from tests.v1.attention.utils import (
create_vllm_config, create_vllm_config,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
if not is_flash_attn_varlen_func_available(): if not is_flash_attn_varlen_func_available():
......
...@@ -5,8 +5,6 @@ import numpy as np ...@@ -5,8 +5,6 @@ import numpy as np
import pytest import pytest
import torch import torch
from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import ( from vllm.config import (
AttentionConfig, AttentionConfig,
...@@ -27,6 +25,8 @@ from vllm.sampling_params import SamplingParams ...@@ -27,6 +25,8 @@ from vllm.sampling_params import SamplingParams
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backend import MultipleOf
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
......
...@@ -73,7 +73,9 @@ EXCLUDE = [ ...@@ -73,7 +73,9 @@ EXCLUDE = [
"vllm/model_executor/models", "vllm/model_executor/models",
"vllm/model_executor/layers/fla/ops", "vllm/model_executor/layers/fla/ops",
# Ignore triton kernels in ops. # Ignore triton kernels in ops.
"vllm/attention/ops", "vllm/v1/attention/ops",
# TODO(matt): remove.
"vllm/v1/attention/backends/fa_utils.py",
] ]
......
...@@ -8,13 +8,6 @@ import torch ...@@ -8,13 +8,6 @@ import torch
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionType,
MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config from vllm.config import CacheConfig, get_current_vllm_config
...@@ -37,6 +30,13 @@ from vllm.utils.torch_utils import ( ...@@ -37,6 +30,13 @@ from vllm.utils.torch_utils import (
direct_register_custom_op, direct_register_custom_op,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
) )
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionType,
MLAAttentionImpl,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
FullAttentionSpec, FullAttentionSpec,
KVCacheSpec, KVCacheSpec,
......
...@@ -6,9 +6,9 @@ from typing import Any, Literal ...@@ -6,9 +6,9 @@ from typing import Any, Literal
from pydantic import field_validator from pydantic import field_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.registry import AttentionBackendEnum
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -12,7 +12,6 @@ from pydantic import ConfigDict, Field, field_validator, model_validator ...@@ -12,7 +12,6 @@ from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.model_arch import ( from vllm.config.model_arch import (
ModelArchitectureConfig, ModelArchitectureConfig,
) )
...@@ -50,6 +49,7 @@ from vllm.transformers_utils.model_arch_config_convertor import ( ...@@ -50,6 +49,7 @@ from vllm.transformers_utils.model_arch_config_convertor import (
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.v1.attention.backends.registry import AttentionBackendEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig from transformers import PretrainedConfig
......
...@@ -7,9 +7,9 @@ from typing import Any, Literal, TypeAlias ...@@ -7,9 +7,9 @@ from typing import Any, Literal, TypeAlias
from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@dataclass @dataclass
...@@ -124,7 +124,7 @@ class MultiModalConfig: ...@@ -124,7 +124,7 @@ class MultiModalConfig:
mm_encoder_attn_backend: AttentionBackendEnum | None = None mm_encoder_attn_backend: AttentionBackendEnum | None = None
"""Optional override for the multi-modal encoder attention backend when """Optional override for the multi-modal encoder attention backend when
using vision transformers. Accepts any value from using vision transformers. Accepts any value from
`vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`).""" `vllm.v1.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
interleave_mm_strings: bool = False interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using """Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string.""" --chat-template-content-format=string."""
......
...@@ -10,10 +10,10 @@ from typing import TYPE_CHECKING, Literal ...@@ -10,10 +10,10 @@ from typing import TYPE_CHECKING, Literal
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -42,8 +42,8 @@ from typing import TYPE_CHECKING, Any, Literal, Optional ...@@ -42,8 +42,8 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
......
...@@ -36,7 +36,6 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -36,7 +36,6 @@ from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.v1 import ( from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorRole, KVConnectorRole,
...@@ -44,6 +43,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( ...@@ -44,6 +43,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionMetadata
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
......
...@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Optional
import safetensors import safetensors
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -16,6 +15,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ...@@ -16,6 +15,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
......
...@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any ...@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import ( from vllm.distributed.kv_events import (
BlockStored, BlockStored,
...@@ -19,6 +18,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( ...@@ -19,6 +18,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorRole, KVConnectorRole,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
......
...@@ -36,7 +36,6 @@ except ImportError: ...@@ -36,7 +36,6 @@ except ImportError:
PluginLauncher as RuntimePluginLauncher, PluginLauncher as RuntimePluginLauncher,
) )
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -54,6 +53,7 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_ ...@@ -54,6 +53,7 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_kv_cache_torch_dtype from vllm.utils.torch_utils import get_kv_cache_torch_dtype
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
......
...@@ -10,13 +10,13 @@ import zmq ...@@ -10,13 +10,13 @@ import zmq
from lmcache.integration.vllm.utils import mla_enabled from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger from lmcache.utils import init_logger as lmcache_init_logger
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorMetadata, KVConnectorMetadata,
KVConnectorRole, KVConnectorRole,
) )
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
......
...@@ -16,8 +16,6 @@ import zmq ...@@ -16,8 +16,6 @@ import zmq
import zmq.asyncio import zmq.asyncio
from vllm import envs from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
...@@ -33,7 +31,9 @@ from vllm.distributed.parallel_state import ( ...@@ -33,7 +31,9 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment