Unverified Commit 2612ba92 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[1/N][Attention] Restructure attention: move files (#31916)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 1f8b7c53
...@@ -15,7 +15,6 @@ import numpy as np ...@@ -15,7 +15,6 @@ import numpy as np
import torch import torch
import zmq import zmq
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -56,11 +55,12 @@ from vllm.utils.network_utils import ( ...@@ -56,11 +55,12 @@ from vllm.utils.network_utils import (
make_zmq_path, make_zmq_path,
make_zmq_socket, make_zmq_socket,
) )
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
......
...@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any ...@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
...@@ -24,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( ...@@ -24,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
PromMetricT, PromMetricT,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
......
...@@ -20,8 +20,6 @@ import torch ...@@ -20,8 +20,6 @@ import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId, EngineId,
...@@ -50,7 +48,9 @@ from vllm.forward_context import ForwardContext ...@@ -50,7 +48,9 @@ from vllm.forward_context import ForwardContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
......
...@@ -8,7 +8,6 @@ from typing import Any ...@@ -8,7 +8,6 @@ from typing import Any
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
...@@ -20,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( ...@@ -20,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
......
...@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re import regex as re
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -19,6 +18,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( ...@@ -19,6 +18,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
) )
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
......
...@@ -32,7 +32,6 @@ from pydantic.fields import FieldInfo ...@@ -32,7 +32,6 @@ from pydantic.fields import FieldInfo
from typing_extensions import TypeIs from typing_extensions import TypeIs
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
AttentionConfig, AttentionConfig,
CacheConfig, CacheConfig,
...@@ -94,6 +93,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser ...@@ -94,6 +93,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip from vllm.utils.network_utils import get_ip
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING: if TYPE_CHECKING:
......
...@@ -684,7 +684,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -684,7 +684,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
None, None,
lambda: list( lambda: list(
__import__( __import__(
"vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"] "vllm.v1.attention.backends.registry", fromlist=["AttentionBackendEnum"]
).AttentionBackendEnum.__members__.keys() ).AttentionBackendEnum.__members__.keys()
), ),
), ),
......
...@@ -10,10 +10,10 @@ from typing import Any, NamedTuple ...@@ -10,10 +10,10 @@ from typing import Any, NamedTuple
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices from vllm.v1.worker.ubatch_utils import UBatchSlices
......
...@@ -4,12 +4,11 @@ import functools ...@@ -4,12 +4,11 @@ import functools
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
...@@ -17,6 +16,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -17,6 +16,7 @@ from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
ChunkedLocalAttentionSpec, ChunkedLocalAttentionSpec,
......
...@@ -6,20 +6,20 @@ from copy import copy ...@@ -6,20 +6,20 @@ from copy import copy
import numpy as np import numpy as np
import torch import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, CommonAttentionMetadata,
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -5,19 +5,19 @@ from copy import copy ...@@ -5,19 +5,19 @@ from copy import copy
import torch import torch
from vllm.attention.backends.abstract import ( from vllm.attention.layer import Attention
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
) )
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, CommonAttentionMetadata,
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
......
...@@ -4,16 +4,16 @@ ...@@ -4,16 +4,16 @@
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import MultiModalConfig from vllm.config import MultiModalConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -4,26 +4,26 @@ import functools ...@@ -4,26 +4,26 @@ import functools
import torch import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, CommonAttentionMetadata,
subclass_attention_backend, subclass_attention_backend,
) )
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
KVCacheSpec, KVCacheSpec,
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from vllm.attention.backends.abstract import AttentionBackend, AttentionImpl
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend, AttentionImpl
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
......
...@@ -6,11 +6,11 @@ from typing import Any ...@@ -6,11 +6,11 @@ from typing import Any
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backends.registry import AttentionBackendEnum
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -5,7 +5,6 @@ import torch ...@@ -5,7 +5,6 @@ import torch
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import ( from vllm.distributed import (
divide, divide,
...@@ -17,6 +16,7 @@ from vllm.logger import init_logger ...@@ -17,6 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from .fla.ops.kda import ( from .fla.ops.kda import (
......
...@@ -5,10 +5,10 @@ from collections.abc import Iterable ...@@ -5,10 +5,10 @@ from collections.abc import Iterable
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_mamba_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.selector import get_mamba_attn_backend
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
......
...@@ -8,7 +8,6 @@ import torch.nn.functional as F ...@@ -8,7 +8,6 @@ import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
...@@ -29,6 +28,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( ...@@ -29,6 +28,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import torch import torch
from torch import nn from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import ( from vllm.distributed import (
divide, divide,
...@@ -43,6 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -43,6 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024 # Added by the IBM Team, 2024
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment