Unverified Commit 2612ba92 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[1/N][Attention] Restructure attention: move files (#31916)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 1f8b7c53
...@@ -270,7 +270,7 @@ def test_reshape_and_cache_flash( ...@@ -270,7 +270,7 @@ def test_reshape_and_cache_flash(
v_scale, v_scale,
) )
elif implementation == "triton": elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash, triton_reshape_and_cache_flash,
) )
......
...@@ -7,12 +7,12 @@ import random ...@@ -7,12 +7,12 @@ import random
import pytest import pytest
import torch import torch
from vllm.attention.ops.flashmla import ( from vllm.triton_utils import triton
from vllm.v1.attention.ops.flashmla import (
flash_mla_with_kvcache, flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
is_flashmla_dense_supported, is_flashmla_dense_supported,
) )
from vllm.triton_utils import triton
def cal_diff( def cal_diff(
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
def test_sparse_flashmla_metadata_smoke(): def test_sparse_flashmla_metadata_smoke():
import vllm.attention.ops.flashmla as fm import vllm.v1.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_sparse_supported() ok, reason = fm.is_flashmla_sparse_supported()
if not ok: if not ok:
...@@ -34,7 +34,7 @@ def test_sparse_flashmla_metadata_smoke(): ...@@ -34,7 +34,7 @@ def test_sparse_flashmla_metadata_smoke():
def test_sparse_flashmla_decode_smoke(): def test_sparse_flashmla_decode_smoke():
import vllm.attention.ops.flashmla as fm import vllm.v1.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_sparse_supported() ok, reason = fm.is_flashmla_sparse_supported()
if not ok: if not ok:
...@@ -97,7 +97,7 @@ def test_sparse_flashmla_decode_smoke(): ...@@ -97,7 +97,7 @@ def test_sparse_flashmla_decode_smoke():
def test_sparse_flashmla_prefill_smoke(): def test_sparse_flashmla_prefill_smoke():
import vllm.attention.ops.flashmla as fm import vllm.v1.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_sparse_supported() ok, reason = fm.is_flashmla_sparse_supported()
if not ok: if not ok:
......
...@@ -5,10 +5,10 @@ import pytest ...@@ -5,10 +5,10 @@ import pytest
import torch import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.attention.ops.triton_merge_attn_states import ( from vllm.platforms import current_platform
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton, merge_attn_states as merge_attn_states_triton,
) )
from vllm.platforms import current_platform
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 # Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
......
...@@ -12,14 +12,14 @@ from unittest.mock import patch ...@@ -12,14 +12,14 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import torch import torch
from torch.testing import assert_close from torch.testing import assert_close
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
def test_pack_seq_basic_fp8(): def test_pack_seq_basic_fp8():
......
...@@ -10,10 +10,12 @@ import pytest ...@@ -10,10 +10,12 @@ import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode,
)
from vllm.v1.attention.ops.prefix_prefill import context_attention_fwd
NUM_HEADS = [64] NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 64] NUM_QUERIES_PER_KV = [1, 64]
......
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -19,7 +19,7 @@ def clear_cache(): ...@@ -19,7 +19,7 @@ def clear_cache():
@pytest.mark.skip(reason="Skipped for now. Should be revisited.") @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch): def test_selector(monkeypatch: pytest.MonkeyPatch):
# Set the current platform to ROCm using monkeypatch # Set the current platform to ROCm using monkeypatch
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform()) monkeypatch.setattr("vllm.v1.attention.selector.current_platform", RocmPlatform())
# Test standard ROCm attention # Test standard ROCm attention
attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN) attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN)
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import pytest import pytest
import torch import torch
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
@pytest.mark.parametrize("B", [3, 5]) @pytest.mark.parametrize("B", [3, 5])
......
...@@ -5,7 +5,7 @@ import pytest ...@@ -5,7 +5,7 @@ import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.attention.ops.triton_prefill_attention import context_attention_fwd from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
def ref_masked_attention( def ref_masked_attention(
......
...@@ -5,10 +5,10 @@ ...@@ -5,10 +5,10 @@
import pytest import pytest
import torch import torch
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
NUM_HEADS = [(4, 4), (8, 2)] NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256] HEAD_SIZES = [128, 256]
......
...@@ -13,11 +13,11 @@ import torch ...@@ -13,11 +13,11 @@ import torch
from torch._prims_common import TensorLikeType from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention.backends.abstract import AttentionType
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.torch_utils import make_tensor_with_pad from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.attention.backend import AttentionType
# For now, disable "test_aot_dispatch_dynamic" since there are some # For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4. # bugs related to this test in PyTorch 2.4.
......
...@@ -14,10 +14,10 @@ import pytest ...@@ -14,10 +14,10 @@ import pytest
from transformers import AutoProcessor from transformers import AutoProcessor
from vllm import LLM, EngineArgs, SamplingParams from vllm import LLM, EngineArgs, SamplingParams
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.multimodal.utils import encode_image_url from vllm.multimodal.utils import encode_image_url
from vllm.multimodal.video import sample_frames_from_video from vllm.multimodal.video import sample_frames_from_video
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from ....utils import create_new_process_for_each_test from ....utils import create_new_process_for_each_test
from ...utils import dummy_hf_overrides from ...utils import dummy_hf_overrides
......
...@@ -9,7 +9,7 @@ Note: these tests will only pass on L4 GPU. ...@@ -9,7 +9,7 @@ Note: these tests will only pass on L4 GPU.
import pytest import pytest
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 from vllm.v1.attention.backends.fa_utils import flash_attn_supports_fp8
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import check_logprobs_close from ..utils import check_logprobs_close
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
) )
from vllm.attention.backends.registry import ( from vllm.v1.attention.backends.registry import (
AttentionBackendEnum, AttentionBackendEnum,
MambaAttentionBackendEnum, MambaAttentionBackendEnum,
register_backend, register_backend,
......
...@@ -15,8 +15,6 @@ from tests.v1.attention.utils import ( ...@@ -15,8 +15,6 @@ from tests.v1.attention.utils import (
create_vllm_config, create_vllm_config,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
...@@ -25,6 +23,8 @@ from vllm.utils.torch_utils import ( ...@@ -25,6 +23,8 @@ from vllm.utils.torch_utils import (
is_torch_equal_or_newer, is_torch_equal_or_newer,
set_random_seed, set_random_seed,
) )
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, CommonAttentionMetadata,
set_kv_cache_layout, set_kv_cache_layout,
......
...@@ -18,15 +18,15 @@ from tests.v1.attention.utils import ( ...@@ -18,15 +18,15 @@ from tests.v1.attention.utils import (
try_get_attention_backend, try_get_attention_backend,
) )
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.config.vllm import set_current_vllm_config from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
from vllm.v1.attention.backends.mla.common import QueryLenSupport from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
......
...@@ -7,9 +7,9 @@ from unittest.mock import MagicMock, patch ...@@ -7,9 +7,9 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import AttentionSelectorConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import AttentionSelectorConfig
# ROCm-specific attention backend selection tests # ROCm-specific attention backend selection tests
pytestmark = pytest.mark.skipif( pytestmark = pytest.mark.skipif(
......
...@@ -21,7 +21,6 @@ from tests.v1.attention.utils import ( ...@@ -21,7 +21,6 @@ from tests.v1.attention.utils import (
create_vllm_config, create_vllm_config,
) )
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.config import set_current_vllm_config from vllm.config import set_current_vllm_config
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -31,6 +30,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import ( ...@@ -31,6 +30,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index, triton_convert_req_index_to_global_index,
) )
from vllm.v1.attention.backends.utils import split_prefill_chunks from vllm.v1.attention.backends.utils import split_prefill_chunks
from vllm.v1.attention.ops import flashmla
SPARSE_BACKEND_BATCH_SPECS = { SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name] name: BATCH_SPECS[name]
......
...@@ -7,8 +7,6 @@ from dataclasses import dataclass ...@@ -7,8 +7,6 @@ from dataclasses import dataclass
import pytest import pytest
import torch import torch
from vllm.attention.backends.abstract import AttentionImpl
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
CompilationConfig, CompilationConfig,
...@@ -20,6 +18,8 @@ from vllm.config import ( ...@@ -20,6 +18,8 @@ from vllm.config import (
VllmConfig, VllmConfig,
) )
from vllm.config.model import ModelDType from vllm.config.model import ModelDType
from vllm.v1.attention.backend import AttentionImpl
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment