Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b665bbc2
Unverified
Commit
b665bbc2
authored
Jan 07, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 07, 2026
Browse files
[Chore] Migrate V0 attention utils (#31891)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
97413875
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
30 additions
and
47 deletions
+30
-47
tests/kernels/mamba/test_causal_conv1d.py
tests/kernels/mamba/test_causal_conv1d.py
+1
-1
tests/kernels/mamba/test_mamba_ssm.py
tests/kernels/mamba/test_mamba_ssm.py
+1
-1
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+0
-33
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+1
-1
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+1
-1
vllm/v1/attention/backends/gdn_attn.py
vllm/v1/attention/backends/gdn_attn.py
+1
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+22
-2
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+1
-2
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+1
-4
vllm/v1/worker/gpu/block_table.py
vllm/v1/worker/gpu/block_table.py
+1
-1
No files found.
tests/kernels/mamba/test_causal_conv1d.py
View file @
b665bbc2
...
@@ -7,12 +7,12 @@ import torch
...
@@ -7,12 +7,12 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
einops
import
rearrange
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_fn
,
causal_conv1d_update
,
causal_conv1d_update
,
)
)
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
def
causal_conv1d_ref
(
def
causal_conv1d_ref
(
...
...
tests/kernels/mamba/test_mamba_ssm.py
View file @
b665bbc2
...
@@ -8,12 +8,12 @@ from einops import rearrange, repeat
...
@@ -8,12 +8,12 @@ from einops import rearrange, repeat
from
tests.kernels.utils
import
opcheck
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_scan_fn
,
selective_scan_fn
,
selective_state_update
,
selective_state_update
,
)
)
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
def
selective_state_update_ref
(
def
selective_state_update_ref
(
...
...
vllm/attention/backends/utils.py
deleted
100644 → 0
View file @
97413875
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from
dataclasses
import
dataclass
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
PAD_SLOT_ID
=
-
1
@
dataclass
class
MLADims
:
q_lora_rank
:
int
|
None
kv_lora_rank
:
int
qk_nope_head_dim
:
int
qk_rope_head_dim
:
int
v_head_dim
:
int
def
get_mla_dims
(
model_config
:
ModelConfig
)
->
MLADims
:
hf_text_config
=
model_config
.
hf_text_config
return
MLADims
(
q_lora_rank
=
getattr
(
hf_text_config
,
"q_lora_rank"
,
None
),
kv_lora_rank
=
hf_text_config
.
kv_lora_rank
,
qk_nope_head_dim
=
hf_text_config
.
qk_nope_head_dim
,
qk_rope_head_dim
=
hf_text_config
.
qk_rope_head_dim
,
v_head_dim
=
hf_text_config
.
v_head_dim
,
)
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
View file @
b665bbc2
...
@@ -8,8 +8,8 @@
...
@@ -8,8 +8,8 @@
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
@
triton
.
jit
()
@
triton
.
jit
()
...
...
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
b665bbc2
...
@@ -8,8 +8,8 @@ import torch
...
@@ -8,8 +8,8 @@ import torch
from
packaging
import
version
from
packaging
import
version
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.triton_utils
import
HAS_TRITON
,
tl
,
triton
from
vllm.triton_utils
import
HAS_TRITON
,
tl
,
triton
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
TRITON3
=
HAS_TRITON
and
(
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
"3.0.0"
))
TRITON3
=
HAS_TRITON
and
(
version
.
parse
(
triton
.
__version__
)
>=
version
.
parse
(
"3.0.0"
))
...
...
vllm/v1/attention/backends/gdn_attn.py
View file @
b665bbc2
...
@@ -7,9 +7,9 @@ from dataclasses import dataclass
...
@@ -7,9 +7,9 @@ from dataclasses import dataclass
import
torch
import
torch
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
PAD_SLOT_ID
,
AttentionCGSupport
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
b665bbc2
...
@@ -205,11 +205,10 @@ from vllm.attention.backends.abstract import (
...
@@ -205,11 +205,10 @@ from vllm.attention.backends.abstract import (
AttentionMetadata
,
AttentionMetadata
,
MLAAttentionImpl
,
MLAAttentionImpl
,
)
)
from
vllm.attention.backends.utils
import
get_mla_dims
from
vllm.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.attention.ops.common
import
cp_lse_ag_out_rs
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
ModelConfig
,
VllmConfig
,
get_current_vllm_config
from
vllm.distributed.parallel_state
import
get_dcp_group
,
is_global_first_rank
from
vllm.distributed.parallel_state
import
get_dcp_group
,
is_global_first_rank
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.batch_invariant
import
(
from
vllm.model_executor.layers.batch_invariant
import
(
...
@@ -479,6 +478,27 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
...
@@ -479,6 +478,27 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
)
)
@
dataclass
class
MLADims
:
q_lora_rank
:
int
|
None
kv_lora_rank
:
int
qk_nope_head_dim
:
int
qk_rope_head_dim
:
int
v_head_dim
:
int
def
get_mla_dims
(
model_config
:
ModelConfig
)
->
MLADims
:
hf_text_config
=
model_config
.
hf_text_config
return
MLADims
(
q_lora_rank
=
getattr
(
hf_text_config
,
"q_lora_rank"
,
None
),
kv_lora_rank
=
hf_text_config
.
kv_lora_rank
,
qk_nope_head_dim
=
hf_text_config
.
qk_nope_head_dim
,
qk_rope_head_dim
=
hf_text_config
.
qk_rope_head_dim
,
v_head_dim
=
hf_text_config
.
v_head_dim
,
)
class
MLACommonMetadataBuilder
(
AttentionMetadataBuilder
[
M
]):
class
MLACommonMetadataBuilder
(
AttentionMetadataBuilder
[
M
]):
"""
"""
NOTE: Please read the comment at the top of the file before trying to
NOTE: Please read the comment at the top of the file before trying to
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
b665bbc2
...
@@ -13,7 +13,6 @@ from vllm.attention.backends.abstract import (
...
@@ -13,7 +13,6 @@ from vllm.attention.backends.abstract import (
AttentionMetadata
,
AttentionMetadata
,
MultipleOf
,
MultipleOf
,
)
)
from
vllm.attention.backends.utils
import
get_mla_dims
from
vllm.attention.ops.flashmla
import
(
from
vllm.attention.ops.flashmla
import
(
flash_mla_sparse_prefill
,
flash_mla_sparse_prefill
,
flash_mla_with_kvcache
,
flash_mla_with_kvcache
,
...
@@ -26,7 +25,7 @@ from vllm.platforms import current_platform
...
@@ -26,7 +25,7 @@ from vllm.platforms import current_platform
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.mla.common
import
MLACommonBaseImpl
from
vllm.v1.attention.backends.mla.common
import
MLACommonBaseImpl
,
get_mla_dims
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
View file @
b665bbc2
...
@@ -14,12 +14,9 @@ from vllm.attention.backends.abstract import (
...
@@ -14,12 +14,9 @@ from vllm.attention.backends.abstract import (
AttentionLayer
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadata
,
)
)
from
vllm.attention.backends.utils
import
get_mla_dims
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
from
vllm.v1.attention.backends.mla.common
import
MLACommonBaseImpl
,
get_mla_dims
MLACommonBaseImpl
,
)
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
triton_convert_req_index_to_global_index
,
triton_convert_req_index_to_global_index
,
)
)
...
...
vllm/v1/worker/gpu/block_table.py
View file @
b665bbc2
...
@@ -4,9 +4,9 @@ from collections.abc import Iterable
...
@@ -4,9 +4,9 @@ from collections.abc import Iterable
import
torch
import
torch
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment