Unverified Commit 2612ba92 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[1/N][Attention] Restructure attention: move files (#31916)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 1f8b7c53
......@@ -270,7 +270,7 @@ def test_reshape_and_cache_flash(
v_scale,
)
elif implementation == "triton":
from vllm.attention.ops.triton_reshape_and_cache_flash import (
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash,
)
......
......@@ -7,12 +7,12 @@ import random
import pytest
import torch
from vllm.attention.ops.flashmla import (
from vllm.triton_utils import triton
from vllm.v1.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_dense_supported,
)
from vllm.triton_utils import triton
def cal_diff(
......
......@@ -5,7 +5,7 @@ import torch
def test_sparse_flashmla_metadata_smoke():
import vllm.attention.ops.flashmla as fm
import vllm.v1.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_sparse_supported()
if not ok:
......@@ -34,7 +34,7 @@ def test_sparse_flashmla_metadata_smoke():
def test_sparse_flashmla_decode_smoke():
import vllm.attention.ops.flashmla as fm
import vllm.v1.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_sparse_supported()
if not ok:
......@@ -97,7 +97,7 @@ def test_sparse_flashmla_decode_smoke():
def test_sparse_flashmla_prefill_smoke():
import vllm.attention.ops.flashmla as fm
import vllm.v1.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_sparse_supported()
if not ok:
......
......@@ -5,10 +5,10 @@ import pytest
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.attention.ops.triton_merge_attn_states import (
from vllm.platforms import current_platform
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton,
)
from vllm.platforms import current_platform
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
......
......@@ -12,14 +12,14 @@ from unittest.mock import patch
import pytest
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.attention.selector import _cached_get_attn_backend
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend
@pytest.fixture(autouse=True)
......
......@@ -4,7 +4,7 @@
import torch
from torch.testing import assert_close
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
def test_pack_seq_basic_fp8():
......
......@@ -10,10 +10,12 @@ import pytest
import torch
import torch.nn.functional as F
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
from vllm.v1.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode,
)
from vllm.v1.attention.ops.prefix_prefill import context_attention_fwd
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 64]
......
......@@ -4,10 +4,10 @@
import pytest
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.platforms.rocm import RocmPlatform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
@pytest.fixture(autouse=True)
......@@ -19,7 +19,7 @@ def clear_cache():
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch):
# Set the current platform to ROCm using monkeypatch
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
monkeypatch.setattr("vllm.v1.attention.selector.current_platform", RocmPlatform())
# Test standard ROCm attention
attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN)
......
......@@ -4,8 +4,8 @@
import pytest
import torch
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
@pytest.mark.parametrize("B", [3, 5])
......
......@@ -5,7 +5,7 @@ import pytest
import torch
import torch.nn.functional as F
from vllm.attention.ops.triton_prefill_attention import context_attention_fwd
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
def ref_masked_attention(
......
......@@ -5,10 +5,10 @@
import pytest
import torch
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
......
......@@ -13,11 +13,11 @@ import torch
from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention.backends.abstract import AttentionType
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.attention.backend import AttentionType
# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
......
......@@ -14,10 +14,10 @@ import pytest
from transformers import AutoProcessor
from vllm import LLM, EngineArgs, SamplingParams
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.multimodal.utils import encode_image_url
from vllm.multimodal.video import sample_frames_from_video
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from ....utils import create_new_process_for_each_test
from ...utils import dummy_hf_overrides
......
......@@ -9,7 +9,7 @@ Note: these tests will only pass on L4 GPU.
import pytest
from tests.quantization.utils import is_quant_method_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_fp8
from vllm.platforms import current_platform
from ..utils import check_logprobs_close
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import (
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
)
from vllm.attention.backends.registry import (
from vllm.v1.attention.backends.registry import (
AttentionBackendEnum,
MambaAttentionBackendEnum,
register_backend,
......
......@@ -15,8 +15,6 @@ from tests.v1.attention.utils import (
create_vllm_config,
try_get_attention_backend,
)
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
......@@ -25,6 +23,8 @@ from vllm.utils.torch_utils import (
is_torch_equal_or_newer,
set_random_seed,
)
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
set_kv_cache_layout,
......
......@@ -18,15 +18,15 @@ from tests.v1.attention.utils import (
try_get_attention_backend,
)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
......
......@@ -7,9 +7,9 @@ from unittest.mock import MagicMock, patch
import pytest
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import AttentionSelectorConfig
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import AttentionSelectorConfig
# ROCm-specific attention backend selection tests
pytestmark = pytest.mark.skipif(
......
......@@ -21,7 +21,6 @@ from tests.v1.attention.utils import (
create_vllm_config,
)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.config import set_current_vllm_config
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.platforms import current_platform
......@@ -31,6 +30,7 @@ from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import split_prefill_chunks
from vllm.v1.attention.ops import flashmla
SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name]
......
......@@ -7,8 +7,6 @@ from dataclasses import dataclass
import pytest
import torch
from vllm.attention.backends.abstract import AttentionImpl
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CacheConfig,
CompilationConfig,
......@@ -20,6 +18,8 @@ from vllm.config import (
VllmConfig,
)
from vllm.config.model import ModelDType
from vllm.v1.attention.backend import AttentionImpl
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment