Unverified Commit b665bbc2 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Chore] Migrate V0 attention utils (#31891)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 97413875
...@@ -7,12 +7,12 @@ import torch ...@@ -7,12 +7,12 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_fn,
causal_conv1d_update, causal_conv1d_update,
) )
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
def causal_conv1d_ref( def causal_conv1d_ref(
......
...@@ -8,12 +8,12 @@ from einops import rearrange, repeat ...@@ -8,12 +8,12 @@ from einops import rearrange, repeat
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401 from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_scan_fn,
selective_state_update, selective_state_update,
) )
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
def selective_state_update_ref( def selective_state_update_ref(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from dataclasses import dataclass
from vllm.config import ModelConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
PAD_SLOT_ID = -1
@dataclass
class MLADims:
q_lora_rank: int | None
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)
...@@ -8,8 +8,8 @@ ...@@ -8,8 +8,8 @@
import numpy as np import numpy as np
import torch import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@triton.jit() @triton.jit()
......
...@@ -8,8 +8,8 @@ import torch ...@@ -8,8 +8,8 @@ import torch
from packaging import version from packaging import version
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0")) TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
......
...@@ -7,9 +7,9 @@ from dataclasses import dataclass ...@@ -7,9 +7,9 @@ from dataclasses import dataclass
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
......
...@@ -205,11 +205,10 @@ from vllm.attention.backends.abstract import ( ...@@ -205,11 +205,10 @@ from vllm.attention.backends.abstract import (
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl, MLAAttentionImpl,
) )
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
...@@ -479,6 +478,27 @@ def use_trtllm_ragged_deepseek_prefill() -> bool: ...@@ -479,6 +478,27 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
) )
@dataclass
class MLADims:
q_lora_rank: int | None
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
""" """
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
......
...@@ -13,7 +13,6 @@ from vllm.attention.backends.abstract import ( ...@@ -13,7 +13,6 @@ from vllm.attention.backends.abstract import (
AttentionMetadata, AttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.flashmla import ( from vllm.attention.ops.flashmla import (
flash_mla_sparse_prefill, flash_mla_sparse_prefill,
flash_mla_with_kvcache, flash_mla_with_kvcache,
...@@ -26,7 +25,7 @@ from vllm.platforms import current_platform ...@@ -26,7 +25,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
......
...@@ -14,12 +14,9 @@ from vllm.attention.backends.abstract import ( ...@@ -14,12 +14,9 @@ from vllm.attention.backends.abstract import (
AttentionLayer, AttentionLayer,
AttentionMetadata, AttentionMetadata,
) )
from vllm.attention.backends.utils import get_mla_dims
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
MLACommonBaseImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import ( from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index, triton_convert_req_index_to_global_index,
) )
......
...@@ -4,9 +4,9 @@ from collections.abc import Iterable ...@@ -4,9 +4,9 @@ from collections.abc import Iterable
import torch import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment