Commit 2263d44b (unverified), authored by Matthew Bonanni, committed by GitHub
Browse files

[4/N][Attention] Move MLA common to model_executor (#32060)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
parent 4f3676e7
......@@ -19,12 +19,12 @@ from tests.v1.attention.utils import (
)
from vllm import _custom_ops as ops
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention.mla_attention import QueryLenSupport
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.v1.kv_cache_interface import FullAttentionSpec
......
......@@ -14,9 +14,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
......
......@@ -18,8 +18,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
)
from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
......
......@@ -33,8 +33,6 @@ class NewLineFormatter(logging.Formatter):
model_executor/.../quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/awq.py ->
model_executor/layers/quantization/awq.py
vllm/v1/attention/backends/mla/common.py ->
v1/attention/backends/mla/common.py
Args:
relpath (Path): The relative path to be shortened.
......
......@@ -9,6 +9,12 @@ import torch
import vllm._custom_ops as ops
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionCGSupport,
......@@ -17,12 +23,6 @@ from vllm.v1.attention.backend import (
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
)
logger = init_logger(__name__)
......
......@@ -9,6 +9,14 @@ import torch
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
......@@ -24,14 +32,6 @@ from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla,
get_flash_attn_version,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,
......
......@@ -8,6 +8,13 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionCGSupport,
......@@ -15,13 +22,6 @@ from vllm.v1.attention.backend import (
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import KVCacheLayoutType
logger = init_logger(__name__)
......
......@@ -9,6 +9,14 @@ import torch
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
......@@ -19,14 +27,6 @@ from vllm.v1.attention.backend import (
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
......
......@@ -10,6 +10,10 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl,
get_mla_dims,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
......@@ -23,7 +27,6 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
......
......@@ -8,8 +8,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
from vllm.v1.attention.backends.mla.common import (
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
......@@ -17,6 +16,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
from vllm.v1.kv_cache_interface import AttentionSpec
......
......@@ -11,6 +11,10 @@ from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl,
get_mla_dims,
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -19,7 +23,6 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
)
......
......@@ -7,6 +7,11 @@ import torch
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
......@@ -16,11 +21,6 @@ from vllm.v1.attention.backend import (
AttentionType,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
logger = init_logger(__name__)
......
......@@ -183,7 +183,9 @@ class EagleProposer:
rocm_types.append(AiterFlashAttentionMetadata)
# TRITON_MLA backend support for MLA models (e.g., DeepSeek)
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadata,
)
rocm_types.append(MLACommonMetadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment