Unverified commit 430dd4d9, authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention] Remove imports from `vllm/attention/__init__.py` (#29342)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent c4c0354e
......@@ -29,7 +29,7 @@ The initialization code should look like this:
```python
from torch import nn
from vllm.config import VllmConfig
from vllm.attention import Attention
from vllm.attention.layer import Attention
class MyAttention(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
......
......@@ -9,8 +9,9 @@ from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
......
......@@ -5,7 +5,8 @@ import pytest
import torch
from tests.compile.backend import TestBackend
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
......
......@@ -14,7 +14,7 @@ import torch
from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils import (
......
......@@ -5,8 +5,8 @@ import numpy as np
import pytest
import torch
from vllm.attention import Attention
from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.layer import Attention
from vllm.config import (
CacheConfig,
ModelConfig,
......
......@@ -7,7 +7,7 @@ from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache():
from vllm.attention import Attention
from vllm.attention.layer import Attention
ctx = {
"layers.0.self_attn": Attention(32, 128, 0.1),
......@@ -35,7 +35,7 @@ def test_bind_kv_cache():
def test_bind_kv_cache_non_attention():
from vllm.attention import Attention
from vllm.attention.layer import Attention
# example from Jamba PP=2
ctx = {
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend, get_mamba_attn_backend
__all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionType",
"get_attn_backend",
"get_mamba_attn_backend",
]
......@@ -178,7 +178,7 @@ class AttentionBackend(ABC):
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionType
return attn_type == AttentionType.DECODER
......
......@@ -10,8 +10,11 @@ import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionType,
MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
......
......@@ -10,7 +10,7 @@ from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.attention import Attention
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
......
......@@ -9,7 +9,7 @@ from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.attention import Attention
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
......
......@@ -20,7 +20,7 @@ import torch
import zmq
from vllm import envs
from vllm.attention import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
......
......@@ -8,7 +8,8 @@ from typing import Any, ClassVar
import torch
from vllm.attention import Attention, AttentionBackend, AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import (
......
......@@ -8,7 +8,7 @@ import torch.nn.functional as F
from einops import rearrange
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
......
......@@ -11,8 +11,7 @@ import torch
from torch import nn
from typing_extensions import assert_never
from vllm.attention import Attention
from vllm.attention.layer import MLAAttention
from vllm.attention.layer import Attention, MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
......
......@@ -9,7 +9,8 @@ from itertools import islice
import torch
from torch import nn
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
......
......@@ -32,7 +32,8 @@ import torch
from torch import nn
from transformers import ApertusConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
......
......@@ -8,7 +8,7 @@ from itertools import islice
import torch
from torch import nn
from vllm.attention import Attention
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
......
......@@ -29,7 +29,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
......
......@@ -32,7 +32,7 @@ import torch.nn.functional as F
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment