Unverified Commit 2612ba92 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[1/N][Attention] Restructure attention: move files (#31916)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 1f8b7c53
......@@ -6,9 +6,9 @@ import random
import pytest
import torch
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(80)),
......
......@@ -14,12 +14,12 @@ from unittest.mock import patch
import pytest
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config
......
......@@ -13,7 +13,6 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
AttentionConfig,
CacheConfig,
......@@ -27,6 +26,7 @@ from vllm.config import (
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
......
......@@ -12,7 +12,6 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CacheConfig,
DeviceConfig,
......@@ -25,6 +24,7 @@ from vllm.config import (
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.spec_decode.eagle import EagleProposer
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
......
......@@ -11,9 +11,9 @@ from tests.v1.attention.utils import (
create_vllm_config,
try_get_attention_backend,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
if not is_flash_attn_varlen_func_available():
......
......@@ -5,8 +5,6 @@ import numpy as np
import pytest
import torch
from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention
from vllm.config import (
AttentionConfig,
......@@ -27,6 +25,8 @@ from vllm.sampling_params import SamplingParams
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backend import MultipleOf
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.kv_cache_interface import (
......
......@@ -73,7 +73,9 @@ EXCLUDE = [
"vllm/model_executor/models",
"vllm/model_executor/layers/fla/ops",
# Ignore triton kernels in ops.
"vllm/attention/ops",
"vllm/v1/attention/ops",
# TODO(matt): remove.
"vllm/v1/attention/backends/fa_utils.py",
]
......
......@@ -8,13 +8,6 @@ import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionType,
MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
from vllm.config import CacheConfig, get_current_vllm_config
......@@ -37,6 +30,13 @@ from vllm.utils.torch_utils import (
direct_register_custom_op,
kv_cache_dtype_str_to_dtype,
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionType,
MLAAttentionImpl,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheSpec,
......
......@@ -6,9 +6,9 @@ from typing import Any, Literal
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.v1.attention.backends.registry import AttentionBackendEnum
logger = init_logger(__name__)
......
......@@ -12,7 +12,6 @@ from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.model_arch import (
ModelArchitectureConfig,
)
......@@ -50,6 +49,7 @@ from vllm.transformers_utils.model_arch_config_convertor import (
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.import_utils import LazyLoader
from vllm.v1.attention.backends.registry import AttentionBackendEnum
if TYPE_CHECKING:
from transformers import PretrainedConfig
......
......@@ -7,9 +7,9 @@ from typing import Any, Literal, TypeAlias
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@dataclass
......@@ -124,7 +124,7 @@ class MultiModalConfig:
mm_encoder_attn_backend: AttentionBackendEnum | None = None
"""Optional override for the multi-modal encoder attention backend when
using vision transformers. Accepts any value from
`vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
`vllm.v1.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
interleave_mm_strings: bool = False
"""Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string."""
......
......@@ -10,10 +10,10 @@ from typing import TYPE_CHECKING, Literal
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
if TYPE_CHECKING:
......
......@@ -42,8 +42,8 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
......
......@@ -36,7 +36,6 @@ from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
......@@ -44,6 +43,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionMetadata
if TYPE_CHECKING:
from vllm.config import VllmConfig
......
......@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
......@@ -16,6 +15,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
)
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
......
......@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import (
BlockStored,
......@@ -19,6 +18,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
......
......@@ -36,7 +36,6 @@ except ImportError:
PluginLauncher as RuntimePluginLauncher,
)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
......@@ -54,6 +53,7 @@ from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_
from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import get_kv_cache_torch_dtype
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.version import __version__ as VLLM_VERSION
......
......@@ -10,13 +10,13 @@ import zmq
from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.request import RequestStatus
......
......@@ -16,8 +16,6 @@ import zmq
import zmq.asyncio
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
......@@ -33,7 +31,9 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment