Unverified commit 430dd4d9, authored by Matthew Bonanni and committed by GitHub
Browse files

[Attention] Remove imports from `vllm/attention/__init__.py` (#29342)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent c4c0354e
...@@ -29,7 +29,7 @@ The initialization code should look like this: ...@@ -29,7 +29,7 @@ The initialization code should look like this:
```python ```python
from torch import nn from torch import nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.attention import Attention from vllm.attention.layer import Attention
class MyAttention(nn.Module): class MyAttention(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str): def __init__(self, vllm_config: VllmConfig, prefix: str):
......
...@@ -9,8 +9,9 @@ from tests.compile.backend import LazyInitPass, TestBackend ...@@ -9,8 +9,9 @@ from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import flat_product from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layer import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.fx_utils import find_op_nodes
......
...@@ -5,7 +5,8 @@ import pytest ...@@ -5,7 +5,8 @@ import pytest
import torch import torch
from tests.compile.backend import TestBackend from tests.compile.backend import TestBackend
from vllm.attention import Attention, AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP from vllm.compilation.matcher_utils import FLASHINFER_ROTARY_OP, RMS_OP, ROTARY_OP
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.post_cleanup import PostCleanupPass
......
...@@ -14,7 +14,7 @@ import torch ...@@ -14,7 +14,7 @@ import torch
from torch._prims_common import TensorLikeType from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils import ( from vllm.utils import (
......
...@@ -5,8 +5,8 @@ import numpy as np ...@@ -5,8 +5,8 @@ import numpy as np
import pytest import pytest
import torch import torch
from vllm.attention import Attention
from vllm.attention.backends.abstract import MultipleOf from vllm.attention.backends.abstract import MultipleOf
from vllm.attention.layer import Attention
from vllm.config import ( from vllm.config import (
CacheConfig, CacheConfig,
ModelConfig, ModelConfig,
......
...@@ -7,7 +7,7 @@ from vllm.v1.worker.utils import bind_kv_cache ...@@ -7,7 +7,7 @@ from vllm.v1.worker.utils import bind_kv_cache
def test_bind_kv_cache(): def test_bind_kv_cache():
from vllm.attention import Attention from vllm.attention.layer import Attention
ctx = { ctx = {
"layers.0.self_attn": Attention(32, 128, 0.1), "layers.0.self_attn": Attention(32, 128, 0.1),
...@@ -35,7 +35,7 @@ def test_bind_kv_cache(): ...@@ -35,7 +35,7 @@ def test_bind_kv_cache():
def test_bind_kv_cache_non_attention(): def test_bind_kv_cache_non_attention():
from vllm.attention import Attention from vllm.attention.layer import Attention
# example from Jamba PP=2 # example from Jamba PP=2
ctx = { ctx = {
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend, get_mamba_attn_backend
__all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionType",
"get_attn_backend",
"get_mamba_attn_backend",
]
...@@ -178,7 +178,7 @@ class AttentionBackend(ABC): ...@@ -178,7 +178,7 @@ class AttentionBackend(ABC):
By default, only supports decoder attention. By default, only supports decoder attention.
Backends should override this to support other attention types. Backends should override this to support other attention types.
""" """
from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionType
return attn_type == AttentionType.DECODER return attn_type == AttentionType.DECODER
......
...@@ -10,8 +10,11 @@ import torch.nn as nn ...@@ -10,8 +10,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionType from vllm.attention.backends.abstract import (
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl AttentionBackend,
AttentionType,
MLAAttentionImpl,
)
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
......
...@@ -10,7 +10,7 @@ from torch import fx ...@@ -10,7 +10,7 @@ from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.attention import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
......
...@@ -9,7 +9,7 @@ from torch import fx ...@@ -9,7 +9,7 @@ from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.attention import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
......
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.attention import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
......
...@@ -8,7 +8,8 @@ from typing import Any, ClassVar ...@@ -8,7 +8,8 @@ from typing import Any, ClassVar
import torch import torch
from vllm.attention import Attention, AttentionBackend, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1 import ( from vllm.distributed.kv_transfer.kv_connector.v1 import (
......
...@@ -8,7 +8,7 @@ import torch.nn.functional as F ...@@ -8,7 +8,7 @@ import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
......
...@@ -11,8 +11,7 @@ import torch ...@@ -11,8 +11,7 @@ import torch
from torch import nn from torch import nn
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.attention import Attention from vllm.attention.layer import Attention, MLAAttention
from vllm.attention.layer import MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
......
...@@ -9,7 +9,8 @@ from itertools import islice ...@@ -9,7 +9,8 @@ from itertools import islice
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import ( from vllm.distributed import (
......
...@@ -32,7 +32,8 @@ import torch ...@@ -32,7 +32,8 @@ import torch
from torch import nn from torch import nn
from transformers import ApertusConfig from transformers import ApertusConfig
from vllm.attention import Attention, AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
......
...@@ -8,7 +8,7 @@ from itertools import islice ...@@ -8,7 +8,7 @@ from itertools import islice
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
......
...@@ -29,7 +29,7 @@ import torch ...@@ -29,7 +29,7 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
......
...@@ -32,7 +32,7 @@ import torch.nn.functional as F ...@@ -32,7 +32,7 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import ( from vllm.distributed import (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment