Unverified commit 2612ba92, authored by Matthew Bonanni, committed by GitHub
Browse files

[1/N][Attention] Restructure attention: move files (#31916)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 1f8b7c53
......@@ -8,15 +8,15 @@ from typing import TYPE_CHECKING, Optional
import torch
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.v1.attention.selector import AttentionSelectorConfig
logger = init_logger(__name__)
......
......@@ -7,14 +7,14 @@ from typing import TYPE_CHECKING, Optional
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.v1.attention.selector import AttentionSelectorConfig
else:
VllmConfig = None
......
......@@ -6,16 +6,16 @@ from typing import ClassVar
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionType,
is_quantized_kv_cache,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
......
......@@ -9,24 +9,24 @@ from typing import ClassVar
import numpy as np
import torch
from vllm.attention.backends.abstract import (
from vllm.attention.layer import Attention
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.attention.layer import Attention
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_fp8,
get_flash_attn_version,
is_flash_attn_varlen_func_available,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
if is_flash_attn_varlen_func_available():
from vllm.attention.utils.fa_utils import (
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks,
flash_attn_varlen_func,
get_scheduler_metadata,
......
......@@ -4,14 +4,14 @@
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.ops.triton_reshape_and_cache_flash import (
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
if is_flash_attn_varlen_func_available():
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import get_kv_cache_layout
......
......@@ -19,14 +19,6 @@ from flashinfer.utils import FP4Tensor
from typing_extensions import override
from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group
......@@ -48,6 +40,12 @@ from vllm.utils.flashinfer import (
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
......@@ -59,6 +57,8 @@ from vllm.v1.attention.backends.utils import (
infer_global_hyperparameters,
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.utils import CpuGpuBuffer
......
......@@ -20,12 +20,6 @@ from torch.nn.attention.flex_attention import (
or_masks,
)
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
is_quantized_kv_cache,
)
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
......@@ -35,6 +29,12 @@ from vllm.model_executor.layers.batch_invariant import (
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
......
......@@ -6,8 +6,8 @@ from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
AttentionCGSupport,
......
......@@ -4,8 +4,8 @@ from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
......
......@@ -3,7 +3,7 @@
from dataclasses import dataclass
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
......
......@@ -5,9 +5,9 @@ from dataclasses import dataclass, replace
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
......
......@@ -199,15 +199,6 @@ from tqdm import tqdm
from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MLAAttentionImpl,
)
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
......@@ -222,6 +213,13 @@ from vllm.model_executor.layers.linear import (
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MLAAttentionImpl,
)
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
......@@ -230,6 +228,8 @@ from vllm.v1.attention.backends.utils import (
infer_global_hyperparameters,
split_decodes_and_prefills,
)
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec
......
......@@ -7,15 +7,15 @@ from typing import ClassVar
import torch
import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
......
......@@ -6,23 +6,23 @@ from typing import ClassVar
import torch
from vllm.attention.backends.abstract import (
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.attention.utils.fa_utils import (
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla,
get_flash_attn_version,
)
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
......
......@@ -6,14 +6,14 @@ from typing import ClassVar
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import (
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
AttentionType,
MultipleOf,
)
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
......
......@@ -6,12 +6,6 @@ from typing import ClassVar
import torch
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_dense_supported,
)
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
......@@ -19,6 +13,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
......@@ -32,6 +27,11 @@ from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
)
from vllm.v1.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_dense_supported,
)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
......@@ -78,11 +78,11 @@ class FlashMLABackend(MLACommonBackend):
device_capability: DeviceCapability,
) -> str | None:
if use_sparse:
from vllm.attention.ops.flashmla import is_flashmla_sparse_supported
from vllm.v1.attention.ops.flashmla import is_flashmla_sparse_supported
return is_flashmla_sparse_supported()[1]
else:
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
return is_flashmla_dense_supported()[1]
......
......@@ -7,17 +7,6 @@ import numpy as np
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MultipleOf,
)
from vllm.attention.ops.flashmla import (
flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
......@@ -25,6 +14,12 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
......@@ -35,6 +30,11 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.attention.ops.flashmla import (
flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
......
......@@ -5,14 +5,14 @@ from typing import ClassVar
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
MultipleOf,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
from vllm.v1.attention.backend import (
AttentionBackend,
MultipleOf,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment