Unverified Commit 2612ba92 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[1/N][Attention] Restructure attention: move files (#31916)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 1f8b7c53
...@@ -7,8 +7,8 @@ from typing import ClassVar ...@@ -7,8 +7,8 @@ from typing import ClassVar
import torch import torch
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionLayer, MultipleOf
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionLayer, MultipleOf
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
......
...@@ -9,13 +9,13 @@ import torch ...@@ -9,13 +9,13 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import ( from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionLayer, AttentionLayer,
AttentionMetadata, AttentionMetadata,
) )
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.mla.flashmla_sparse import ( from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index, triton_convert_req_index_to_global_index,
......
...@@ -5,23 +5,23 @@ from typing import ClassVar ...@@ -5,23 +5,23 @@ from typing import ClassVar
import torch import torch
from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionType,
is_quantized_kv_cache,
)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionLayer,
AttentionType,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
) )
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -10,7 +10,7 @@ from vllm.logger import init_logger ...@@ -10,7 +10,7 @@ from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backend import AttentionBackend
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -7,25 +7,25 @@ from typing import ClassVar ...@@ -7,25 +7,25 @@ from typing import ClassVar
import torch import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import get_cu_count from vllm.utils.platform_utils import get_cu_count
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
split_decodes_prefills_and_extends, split_decodes_prefills_and_extends,
) )
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
_PARTITION_SIZE_ROCM = 256 _PARTITION_SIZE_ROCM = 256
......
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import ( from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend, RocmAttentionBackend,
......
...@@ -7,17 +7,6 @@ from typing import ClassVar ...@@ -7,17 +7,6 @@ from typing import ClassVar
import torch import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
...@@ -25,12 +14,25 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -25,12 +14,25 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode,
)
from vllm.v1.attention.ops.paged_attn import PagedAttention
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from vllm.attention.backends.abstract import AttentionBackend from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mamba_attn import ( from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata, BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder, BaseMambaAttentionMetadataBuilder,
......
...@@ -9,20 +9,20 @@ from typing import ClassVar, Optional ...@@ -9,20 +9,20 @@ from typing import ClassVar, Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import ( from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
) )
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -7,17 +7,6 @@ from typing import ClassVar ...@@ -7,17 +7,6 @@ from typing import ClassVar
import torch import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -28,11 +17,22 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -28,11 +17,22 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -29,16 +29,16 @@ if TYPE_CHECKING: ...@@ -29,16 +29,16 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout, get_kv_connector_cache_layout,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice from vllm.v1.worker.ubatch_utils import UBatchSlice
......
...@@ -40,7 +40,7 @@ def merge_attn_states( ...@@ -40,7 +40,7 @@ def merge_attn_states(
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
) )
else: else:
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states
return merge_attn_states( return merge_attn_states(
output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment