Unverified Commit 2612ba92 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[1/N][Attention] Restructure attention: move files (#31916)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 1f8b7c53
...@@ -8,15 +8,15 @@ from typing import TYPE_CHECKING, Optional ...@@ -8,15 +8,15 @@ from typing import TYPE_CHECKING, Optional
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.selector import AttentionSelectorConfig
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -7,14 +7,14 @@ from typing import TYPE_CHECKING, Optional ...@@ -7,14 +7,14 @@ from typing import TYPE_CHECKING, Optional
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.selector import AttentionSelectorConfig
else: else:
VllmConfig = None VllmConfig = None
......
...@@ -6,16 +6,16 @@ from typing import ClassVar ...@@ -6,16 +6,16 @@ from typing import ClassVar
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import ( from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
......
...@@ -9,24 +9,24 @@ from typing import ClassVar ...@@ -9,24 +9,24 @@ from typing import ClassVar
import numpy as np import numpy as np
import torch import torch
from vllm.attention.backends.abstract import ( from vllm.attention.layer import Attention
from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.attention.layer import Attention from vllm.v1.attention.backends.fa_utils import (
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8, flash_attn_supports_fp8,
get_flash_attn_version, get_flash_attn_version,
is_flash_attn_varlen_func_available, is_flash_attn_varlen_func_available,
) )
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
if is_flash_attn_varlen_func_available(): if is_flash_attn_varlen_func_available():
from vllm.attention.utils.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_sinks, flash_attn_supports_sinks,
flash_attn_varlen_func, flash_attn_varlen_func,
get_scheduler_metadata, get_scheduler_metadata,
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
import torch import torch
from vllm.attention.backends.abstract import AttentionType from vllm.v1.attention.backend import AttentionType
from vllm.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv, triton_reshape_and_cache_flash_diffkv,
) )
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
if is_flash_attn_varlen_func_available(): if is_flash_attn_varlen_func_available():
from vllm.attention.utils.fa_utils import flash_attn_varlen_func from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
......
...@@ -19,14 +19,6 @@ from flashinfer.utils import FP4Tensor ...@@ -19,14 +19,6 @@ from flashinfer.utils import FP4Tensor
from typing_extensions import override from typing_extensions import override
from vllm import envs from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.distributed.parallel_state import get_dcp_group from vllm.distributed.parallel_state import get_dcp_group
...@@ -48,6 +40,12 @@ from vllm.utils.flashinfer import ( ...@@ -48,6 +40,12 @@ from vllm.utils.flashinfer import (
) )
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
...@@ -59,6 +57,8 @@ from vllm.v1.attention.backends.utils import ( ...@@ -59,6 +57,8 @@ from vllm.v1.attention.backends.utils import (
infer_global_hyperparameters, infer_global_hyperparameters,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
......
...@@ -20,12 +20,6 @@ from torch.nn.attention.flex_attention import ( ...@@ -20,12 +20,6 @@ from torch.nn.attention.flex_attention import (
or_masks, or_masks,
) )
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
is_quantized_kv_cache,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -35,6 +29,12 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -35,6 +29,12 @@ from vllm.model_executor.layers.batch_invariant import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
......
...@@ -6,8 +6,8 @@ from dataclasses import dataclass ...@@ -6,8 +6,8 @@ from dataclasses import dataclass
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID, PAD_SLOT_ID,
AttentionCGSupport, AttentionCGSupport,
......
...@@ -4,8 +4,8 @@ from dataclasses import dataclass ...@@ -4,8 +4,8 @@ from dataclasses import dataclass
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import ( from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata, BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder, BaseMambaAttentionMetadataBuilder,
......
...@@ -5,9 +5,9 @@ from dataclasses import dataclass, replace ...@@ -5,9 +5,9 @@ from dataclasses import dataclass, replace
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import ( from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata, BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder, BaseMambaAttentionMetadataBuilder,
......
...@@ -199,15 +199,6 @@ from tqdm import tqdm ...@@ -199,15 +199,6 @@ from tqdm import tqdm
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MLAAttentionImpl,
)
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -222,6 +213,13 @@ from vllm.model_executor.layers.linear import ( ...@@ -222,6 +213,13 @@ from vllm.model_executor.layers.linear import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_nvidia_artifactory from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MLAAttentionImpl,
)
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
...@@ -230,6 +228,8 @@ from vllm.v1.attention.backends.utils import ( ...@@ -230,6 +228,8 @@ from vllm.v1.attention.backends.utils import (
infer_global_hyperparameters, infer_global_hyperparameters,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
......
...@@ -7,15 +7,15 @@ from typing import ClassVar ...@@ -7,15 +7,15 @@ from typing import ClassVar
import torch import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
from vllm.attention.backends.abstract import ( from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonImpl, MLACommonImpl,
......
...@@ -6,23 +6,23 @@ from typing import ClassVar ...@@ -6,23 +6,23 @@ from typing import ClassVar
import torch import torch
from vllm.attention.backends.abstract import ( from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.attention.utils.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla, flash_attn_supports_mla,
get_flash_attn_version, get_flash_attn_version,
) )
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
......
...@@ -6,14 +6,14 @@ from typing import ClassVar ...@@ -6,14 +6,14 @@ from typing import ClassVar
import torch import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import ( from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
) )
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonImpl, MLACommonImpl,
......
...@@ -6,12 +6,6 @@ from typing import ClassVar ...@@ -6,12 +6,6 @@ from typing import ClassVar
import torch import torch
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_dense_supported,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -19,6 +13,7 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -19,6 +13,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
...@@ -32,6 +27,11 @@ from vllm.v1.attention.backends.utils import ( ...@@ -32,6 +27,11 @@ from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode, reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode, reshape_query_for_spec_decode,
) )
from vllm.v1.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_dense_supported,
)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -78,11 +78,11 @@ class FlashMLABackend(MLACommonBackend): ...@@ -78,11 +78,11 @@ class FlashMLABackend(MLACommonBackend):
device_capability: DeviceCapability, device_capability: DeviceCapability,
) -> str | None: ) -> str | None:
if use_sparse: if use_sparse:
from vllm.attention.ops.flashmla import is_flashmla_sparse_supported from vllm.v1.attention.ops.flashmla import is_flashmla_sparse_supported
return is_flashmla_sparse_supported()[1] return is_flashmla_sparse_supported()[1]
else: else:
from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
return is_flashmla_dense_supported()[1] return is_flashmla_dense_supported()[1]
......
...@@ -7,17 +7,6 @@ import numpy as np ...@@ -7,17 +7,6 @@ import numpy as np
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MultipleOf,
)
from vllm.attention.ops.flashmla import (
flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -25,6 +14,12 @@ from vllm.platforms import current_platform ...@@ -25,6 +14,12 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
...@@ -35,6 +30,11 @@ from vllm.v1.attention.backends.utils import ( ...@@ -35,6 +30,11 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills, split_decodes_and_prefills,
split_prefill_chunks, split_prefill_chunks,
) )
from vllm.v1.attention.ops.flashmla import (
flash_mla_sparse_prefill,
flash_mla_with_kvcache,
get_mla_metadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager from vllm.v1.worker.workspace import current_workspace_manager
......
...@@ -5,14 +5,14 @@ from typing import ClassVar ...@@ -5,14 +5,14 @@ from typing import ClassVar
import torch import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
MultipleOf,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
from vllm.v1.attention.backend import (
AttentionBackend,
MultipleOf,
)
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment