Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3951d3ea
Unverified
Commit
3951d3ea
authored
Apr 22, 2026
by
Martin Hickey
Committed by
GitHub
Apr 21, 2026
Browse files
[MyPy] Enable mypy for `vllm/model_executor/layers/` (#40159)
Signed-off-by:
Martin Hickey
<
martin.hickey@ie.ibm.com
>
parent
6f2c71be
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
54 additions
and
27 deletions
+54
-27
vllm/model_executor/layers/mamba/mamba_mixer.py
vllm/model_executor/layers/mamba/mamba_mixer.py
+9
-4
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+17
-4
vllm/model_executor/layers/mamba/short_conv.py
vllm/model_executor/layers/mamba/short_conv.py
+5
-4
vllm/model_executor/layers/pooler/seqwise/poolers.py
vllm/model_executor/layers/pooler/seqwise/poolers.py
+1
-0
vllm/model_executor/layers/pooler/tokwise/poolers.py
vllm/model_executor/layers/pooler/tokwise/poolers.py
+1
-0
vllm/model_executor/layers/quantization/fp_quant.py
vllm/model_executor/layers/quantization/fp_quant.py
+6
-2
vllm/model_executor/layers/quantization/quark/quark.py
vllm/model_executor/layers/quantization/quark/quark.py
+1
-1
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+14
-12
No files found.
vllm/model_executor/layers/mamba/mamba_mixer.py
View file @
3951d3ea
...
...
@@ -40,6 +40,7 @@ from vllm.utils.torch_utils import (
_resolve_layer_name
,
direct_register_custom_op
,
)
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.mamba1_attn
import
Mamba1AttentionMetadata
...
...
@@ -258,15 +259,16 @@ class MambaMixer(MambaBase, PluggableLayer):
"""
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
attn_metadata
attn_metadata
_raw
=
forward_context
.
attn_metadata
assert
self
.
cache_config
is
not
None
mamba_block_size
=
self
.
cache_config
.
mamba_block_size
is_mamba_cache_all
=
self
.
cache_config
.
mamba_cache_mode
==
"all"
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
:
AttentionMetadata
|
None
=
None
if
attn_metadata_raw
is
not
None
:
assert
isinstance
(
attn_metadata_raw
,
dict
)
attn_metadata
=
attn_metadata_raw
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
Mamba1AttentionMetadata
)
query_start_loc_p
=
attn_metadata
.
query_start_loc_p
state_indices_tensor_p
=
attn_metadata
.
state_indices_tensor_p
...
...
@@ -391,6 +393,9 @@ class MambaMixer(MambaBase, PluggableLayer):
ssm_outputs
.
append
(
scan_out_p
)
if
has_decode
:
# state_indices_tensor_d is assigned when attn_metadata is not None,
# and has_decode is only True when attn_metadata is not None
assert
state_indices_tensor_d
is
not
None
if
is_mamba_cache_all
:
state_indices_tensor_d_input
=
state_indices_tensor_d
.
gather
(
1
,
block_idx_last_computed_token_d
.
unsqueeze
(
1
)
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
View file @
3951d3ea
...
...
@@ -572,14 +572,16 @@ class MambaMixer2(MambaBase, PluggableLayer):
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
attn_metadata
_raw
=
forward_context
.
attn_metadata
assert
self
.
cache_config
is
not
None
mamba_block_size
=
self
.
cache_config
.
mamba_block_size
is_mamba_cache_all
=
self
.
cache_config
.
mamba_cache_mode
==
"all"
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata
:
AttentionMetadata
|
None
=
None
if
attn_metadata_raw
is
not
None
:
assert
isinstance
(
attn_metadata_raw
,
dict
)
attn_metadata
=
attn_metadata_raw
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
Mamba2AttentionMetadata
)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a
...
...
@@ -708,6 +710,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
# 3. State Space Model sequence transformation
initial_states
=
None
if
has_initial_states_p
is
not
None
and
prep_initial_states
:
assert
state_indices_tensor_p
is
not
None
kernel_ssm_indices
=
state_indices_tensor_p
if
is_mamba_cache_all
:
kernel_ssm_indices
=
state_indices_tensor_p
.
gather
(
...
...
@@ -746,6 +749,13 @@ class MambaMixer2(MambaBase, PluggableLayer):
)
if
is_mamba_cache_all
:
assert
mamba_block_size
is
not
None
assert
state_indices_tensor_p
is
not
None
assert
block_idx_first_scheduled_token_p
is
not
None
assert
block_idx_last_scheduled_token_p
is
not
None
assert
last_chunk_indices_p
is
not
None
assert
num_computed_tokens_p
is
not
None
# The chunk_stride is the number of chunks per mamba block
# e.g., if mamba_block_size = 512 and chunk_size = 256,
# then chunk_stride = 2
...
...
@@ -810,6 +820,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
ssm_state
[
cache_blocks_to_fill
]
=
from_where
# For all seqs, store the last state (note: might be partial):
assert
state_indices_tensor_p
is
not
None
ssm_state
[
state_indices_tensor_p
.
gather
(
1
,
block_idx_last_scheduled_token_p
.
unsqueeze
(
1
)
...
...
@@ -820,10 +831,12 @@ class MambaMixer2(MambaBase, PluggableLayer):
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
# tensor
assert
state_indices_tensor_p
is
not
None
ssm_state
[
state_indices_tensor_p
]
=
varlen_states
# Process decode requests
if
has_decode
:
assert
state_indices_tensor_d
is
not
None
if
is_mamba_cache_all
:
state_indices_tensor_d_input
=
state_indices_tensor_d
.
gather
(
1
,
block_idx_last_computed_token_d
.
unsqueeze
(
1
)
...
...
vllm/model_executor/layers/mamba/short_conv.py
View file @
3951d3ea
...
...
@@ -113,10 +113,11 @@ class ShortConv(MambaBase, CustomOp):
# chunked prefill modes; they are computed at top-level model forward
# since they stay the same and reused for all mamba layers in the same
# iteration.
attn_metadata
:
AttentionMetadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
not
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata
=
attn_metadata
[
self
.
prefix
]
attn_metadata_raw
=
forward_context
.
attn_metadata
attn_metadata
:
AttentionMetadata
|
None
=
None
if
attn_metadata_raw
is
not
None
:
assert
isinstance
(
attn_metadata_raw
,
dict
)
attn_metadata
=
attn_metadata_raw
[
self
.
prefix
]
assert
isinstance
(
attn_metadata
,
ShortConvAttentionMetadata
)
conv_state
=
(
self
.
kv_cache
[
0
]
...
...
vllm/model_executor/layers/pooler/seqwise/poolers.py
View file @
3951d3ea
...
...
@@ -115,6 +115,7 @@ def pooler_for_classify(
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
assert
model_config
.
pooler_config
is
not
None
head
=
ClassifierPoolerHead
(
head_dtype
=
model_config
.
head_dtype
,
classifier
=
classifier
,
...
...
vllm/model_executor/layers/pooler/tokwise/poolers.py
View file @
3951d3ea
...
...
@@ -124,6 +124,7 @@ def pooler_for_token_classify(
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
assert
model_config
.
pooler_config
is
not
None
head
=
TokenClassifierPoolerHead
(
head_dtype
=
model_config
.
head_dtype
,
classifier
=
classifier
,
...
...
vllm/model_executor/layers/quantization/fp_quant.py
View file @
3951d3ea
...
...
@@ -3,7 +3,7 @@
# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202
from
typing
import
Any
from
typing
import
Any
,
Literal
,
cast
import
torch
from
torch.nn.parameter
import
Parameter
...
...
@@ -251,7 +251,11 @@ class FPQuantLinearMethod(LinearMethodBase):
def
fused_quantize_mx
(
x_flat
:
torch
.
Tensor
,
hadamard_matrix
:
torch
.
Tensor
,
forward_method
:
str
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
fusedQuantizeMx
(
x_flat
,
hadamard_matrix
,
method
=
forward_method
)
return
fusedQuantizeMx
(
x_flat
,
hadamard_matrix
,
method
=
cast
(
Literal
[
"quest"
,
"abs_max"
],
forward_method
),
)
def
fused_quantize_mx_fake
(
x_flat
,
hadamard_matrix
,
forward_method
):
...
...
vllm/model_executor/layers/quantization/quark/quark.py
View file @
3951d3ea
...
...
@@ -114,7 +114,7 @@ class QuarkConfig(QuantizationConfig):
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
quant_config_with_hf_to_vllm_mapper
=
{}
quant_config_with_hf_to_vllm_mapper
:
dict
[
str
,
Any
]
=
{}
for
k
,
v
in
self
.
quant_config
.
items
():
if
isinstance
(
v
,
list
):
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
3951d3ea
...
...
@@ -26,7 +26,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
from
vllm._xpu_ops
import
xpu_ops
as
ops
from
vllm._xpu_ops
import
xpu_ops
logger
=
init_logger
(
__name__
)
...
...
@@ -84,12 +84,12 @@ def sparse_attn_indexer(
total_seq_lens
,
topk_indices_buffer
,
)
attn_metadata
=
attn_metadata
[
k_cache_prefix
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
has_decode
=
attn_metadata
.
num_decodes
>
0
has_prefill
=
attn_metadata
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
attn_metadata
_narrowed
=
attn_metadata
[
k_cache_prefix
]
assert
isinstance
(
attn_metadata
_narrowed
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
_narrowed
.
slot_mapping
has_decode
=
attn_metadata
_narrowed
.
num_decodes
>
0
has_prefill
=
attn_metadata
_narrowed
.
num_prefills
>
0
num_decode_tokens
=
attn_metadata
_narrowed
.
num_decode_tokens
# During speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens. Truncate k to avoid
...
...
@@ -97,6 +97,8 @@ def sparse_attn_indexer(
num_tokens
=
slot_mapping
.
shape
[
0
]
k
=
k
[:
num_tokens
]
# scale_fmt can be None, but the function expects str
assert
scale_fmt
is
not
None
ops
.
indexer_k_quant_and_cache
(
k
,
kv_cache
,
...
...
@@ -107,7 +109,7 @@ def sparse_attn_indexer(
topk_indices_buffer
[:
hidden_states
.
shape
[
0
]]
=
-
1
if
has_prefill
:
prefill_metadata
=
attn_metadata
.
prefill
prefill_metadata
=
attn_metadata
_narrowed
.
prefill
assert
prefill_metadata
is
not
None
# Get the full shared workspace buffers once (will allocate on first use)
...
...
@@ -144,7 +146,7 @@ def sparse_attn_indexer(
]
if
current_platform
.
is_xpu
():
ops
.
top_k_per_row_prefill
(
xpu_
ops
.
top_k_per_row_prefill
(
# type: ignore[attr-defined]
logits
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
...
...
@@ -167,7 +169,7 @@ def sparse_attn_indexer(
)
if
has_decode
:
decode_metadata
=
attn_metadata
.
decode
decode_metadata
=
attn_metadata
_narrowed
.
decode
assert
decode_metadata
is
not
None
# kv_cache shape [
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
...
...
@@ -217,11 +219,11 @@ def sparse_attn_indexer(
topk_indices
,
topk_workspace
,
topk_tokens
,
attn_metadata
.
max_seq_len
,
attn_metadata
_narrowed
.
max_seq_len
,
)
else
:
if
current_platform
.
is_xpu
():
ops
.
top_k_per_row_decode
(
xpu_
ops
.
top_k_per_row_decode
(
# type: ignore[attr-defined]
logits
,
next_n
,
seq_lens
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment