Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aec18492
Unverified
Commit
aec18492
authored
Apr 08, 2026
by
Wentao Ye
Committed by
GitHub
Apr 09, 2026
Browse files
[CI] Fix mypy for `vllm/v1/ops` (#39219)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
2a49284c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
28 additions
and
25 deletions
+28
-25
tools/pre_commit/mypy.py
tools/pre_commit/mypy.py
+0
-2
vllm/v1/attention/ops/prefix_prefill.py
vllm/v1/attention/ops/prefix_prefill.py
+3
-1
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+22
-22
vllm/v1/attention/ops/vit_attn_wrappers.py
vllm/v1/attention/ops/vit_attn_wrappers.py
+3
-0
No files found.
tools/pre_commit/mypy.py
View file @
aec18492
...
...
@@ -36,8 +36,6 @@ SEPARATE_GROUPS = [
EXCLUDE
=
[
"vllm/model_executor/models"
,
"vllm/model_executor/layers/fla/ops"
,
# Ignore triton kernels in ops.
"vllm/v1/attention/ops"
,
# TODO: Remove these entries after fixing mypy errors.
"vllm/benchmarks"
,
]
...
...
vllm/v1/attention/ops/prefix_prefill.py
View file @
aec18492
...
...
@@ -4,6 +4,8 @@
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
from
typing
import
Any
import
torch
from
vllm.platforms
import
current_platform
...
...
@@ -780,7 +782,7 @@ def context_attention_fwd(
return
max_seq_len
=
0
if
max_seq_len
is
None
else
max_seq_len
extra_kargs
=
{}
extra_kargs
:
dict
[
str
,
Any
]
=
{}
if
current_platform
.
is_rocm
():
extra_kargs
=
{}
...
...
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
View file @
aec18492
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
importlib
from
importlib.util
import
find_spec
import
torch
...
...
@@ -276,11 +277,9 @@ def fp8_paged_mqa_logits_torch(
@
functools
.
lru_cache
def
paged_mqa_logits_module
():
paged_mqa_logits_module_path
=
None
if
importlib
.
util
.
find_spec
(
"aiter.ops.triton.pa_mqa_logits"
)
is
not
None
:
if
find_spec
(
"aiter.ops.triton.pa_mqa_logits"
)
is
not
None
:
paged_mqa_logits_module_path
=
"aiter.ops.triton.pa_mqa_logits"
elif
(
importlib
.
util
.
find_spec
(
"aiter.ops.triton.attention.pa_mqa_logits"
)
is
not
None
):
elif
find_spec
(
"aiter.ops.triton.attention.pa_mqa_logits"
)
is
not
None
:
paged_mqa_logits_module_path
=
"aiter.ops.triton.attention.pa_mqa_logits"
if
paged_mqa_logits_module_path
is
not
None
:
...
...
@@ -380,9 +379,9 @@ def fp8_mqa_logits_torch(
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
k
v
,
scale
=
kv
seq_len_kv
=
k
v
.
shape
[
0
]
k
=
k
v
.
to
(
torch
.
bfloat16
)
k
_fp8
,
scale
=
kv
seq_len_kv
=
k
_fp8
.
shape
[
0
]
k
=
k
_fp8
.
to
(
torch
.
bfloat16
)
q
=
q
.
to
(
torch
.
bfloat16
)
mask_lo
=
(
...
...
@@ -403,12 +402,9 @@ def fp8_mqa_logits_torch(
@
functools
.
lru_cache
def
mqa_logits_module
():
mqa_logits_module_path
=
None
if
importlib
.
util
.
find_spec
(
"aiter.ops.triton.fp8_mqa_logits"
)
is
not
None
:
if
find_spec
(
"aiter.ops.triton.fp8_mqa_logits"
)
is
not
None
:
mqa_logits_module_path
=
"aiter.ops.triton.fp8_mqa_logits"
elif
(
importlib
.
util
.
find_spec
(
"aiter.ops.triton.attention.fp8_mqa_logits"
)
is
not
None
):
elif
find_spec
(
"aiter.ops.triton.attention.fp8_mqa_logits"
)
is
not
None
:
mqa_logits_module_path
=
"aiter.ops.triton.attention.fp8_mqa_logits"
if
mqa_logits_module_path
is
not
None
:
...
...
@@ -455,8 +451,8 @@ def rocm_fp8_mqa_logits(
if
aiter_mqa_logits_module
is
not
None
:
fp8_mqa_logits
=
aiter_mqa_logits_module
.
fp8_mqa_logits
k
v
,
scale
=
kv
return
fp8_mqa_logits
(
q
,
k
v
,
scale
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
)
k
_fp8
,
scale
=
kv
return
fp8_mqa_logits
(
q
,
k
_fp8
,
scale
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
)
else
:
return
fp8_mqa_logits_torch
(
q
,
kv
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
)
...
...
@@ -523,12 +519,14 @@ def rocm_aiter_sparse_attn_indexer(
total_seq_lens
,
topk_indices_buffer
,
)
attn_metadata
=
attn_metadata
[
k_cache_prefix
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
layer_attn_metadata
=
attn_metadata
[
k_cache_prefix
]
assert
isinstance
(
layer_attn_metadata
,
DeepseekV32IndexerMetadata
)
assert
topk_indices_buffer
is
not
None
assert
scale_fmt
is
not
None
slot_mapping
=
layer_attn_metadata
.
slot_mapping
has_decode
=
layer_attn_metadata
.
num_decodes
>
0
has_prefill
=
layer_attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
layer_attn_metadata
.
num_decode_tokens
ops
.
indexer_k_quant_and_cache
(
k
,
...
...
@@ -540,7 +538,8 @@ def rocm_aiter_sparse_attn_indexer(
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
prefill_metadata
=
attn_metadata
.
prefill
prefill_metadata
=
layer_attn_metadata
.
prefill
assert
prefill_metadata
is
not
None
for
chunk
in
prefill_metadata
.
chunks
:
k_fp8
=
torch
.
empty
(
[
chunk
.
total_seq_lens
,
head_dim
],
...
...
@@ -585,7 +584,8 @@ def rocm_aiter_sparse_attn_indexer(
)
if
has_decode
:
decode_metadata
=
attn_metadata
.
decode
decode_metadata
=
layer_attn_metadata
.
decode
assert
decode_metadata
is
not
None
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache
=
kv_cache
.
unsqueeze
(
-
2
)
...
...
vllm/v1/attention/ops/vit_attn_wrappers.py
View file @
aec18492
...
...
@@ -292,6 +292,9 @@ def flashinfer_wrapper(
# RoPE has already made q and k contiguous.
q
,
k
=
q
.
contiguous
(),
k
.
contiguous
()
assert
cu_seqlens
is
not
None
assert
max_seqlen
is
not
None
assert
sequence_lengths
is
not
None
assert
len
(
cu_seqlens
)
%
2
==
0
,
"cu_seqlens must be divisible by 2"
cu_seqlength
=
len
(
cu_seqlens
)
//
2
batch_offsets_qko
=
cu_seqlens
[:
cu_seqlength
].
view
(
-
1
,
1
,
1
,
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment