Unverified commit 3951d3ea, authored by Martin Hickey, committed by GitHub
Browse files

[MyPy] Enable mypy for `vllm/model_executor/layers/` (#40159)


Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
parent 6f2c71be
......@@ -40,6 +40,7 @@ from vllm.utils.torch_utils import (
_resolve_layer_name,
direct_register_custom_op,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
......@@ -258,15 +259,16 @@ class MambaMixer(MambaBase, PluggableLayer):
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
attn_metadata_raw = forward_context.attn_metadata
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
attn_metadata: AttentionMetadata | None = None
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, Mamba1AttentionMetadata)
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
......@@ -391,6 +393,9 @@ class MambaMixer(MambaBase, PluggableLayer):
ssm_outputs.append(scan_out_p)
if has_decode:
# state_indices_tensor_d is assigned when attn_metadata is not None,
# and has_decode is only True when attn_metadata is not None
assert state_indices_tensor_d is not None
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
......
......@@ -572,14 +572,16 @@ class MambaMixer2(MambaBase, PluggableLayer):
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata
attn_metadata_raw = forward_context.attn_metadata
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
attn_metadata: AttentionMetadata | None = None
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
# conv_state must be (..., dim, width-1) for the conv kernels.
# DS layout stores it that way directly; SD layout needs a
......@@ -708,6 +710,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
# 3. State Space Model sequence transformation
initial_states = None
if has_initial_states_p is not None and prep_initial_states:
assert state_indices_tensor_p is not None
kernel_ssm_indices = state_indices_tensor_p
if is_mamba_cache_all:
kernel_ssm_indices = state_indices_tensor_p.gather(
......@@ -746,6 +749,13 @@ class MambaMixer2(MambaBase, PluggableLayer):
)
if is_mamba_cache_all:
assert mamba_block_size is not None
assert state_indices_tensor_p is not None
assert block_idx_first_scheduled_token_p is not None
assert block_idx_last_scheduled_token_p is not None
assert last_chunk_indices_p is not None
assert num_computed_tokens_p is not None
# The chunk_stride is the number of chunks per mamba block
# e.g., if mamba_block_size = 512 and chunk_size = 256,
# then chunk_stride = 2
......@@ -810,6 +820,7 @@ class MambaMixer2(MambaBase, PluggableLayer):
ssm_state[cache_blocks_to_fill] = from_where
# For all seqs, store the last state (note: might be partial):
assert state_indices_tensor_p is not None
ssm_state[
state_indices_tensor_p.gather(
1, block_idx_last_scheduled_token_p.unsqueeze(1)
......@@ -820,10 +831,12 @@ class MambaMixer2(MambaBase, PluggableLayer):
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
# tensor
assert state_indices_tensor_p is not None
ssm_state[state_indices_tensor_p] = varlen_states
# Process decode requests
if has_decode:
assert state_indices_tensor_d is not None
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
......
......@@ -113,10 +113,11 @@ class ShortConv(MambaBase, CustomOp):
# chunked prefill modes; they are computed at top-level model forward
# since they stay the same and reused for all mamba layers in the same
# iteration.
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
attn_metadata_raw = forward_context.attn_metadata
attn_metadata: AttentionMetadata | None = None
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
conv_state = (
self.kv_cache[0]
......
......@@ -115,6 +115,7 @@ def pooler_for_classify(
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
assert model_config.pooler_config is not None
head = ClassifierPoolerHead(
head_dtype=model_config.head_dtype,
classifier=classifier,
......
......@@ -124,6 +124,7 @@ def pooler_for_token_classify(
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
assert model_config.pooler_config is not None
head = TokenClassifierPoolerHead(
head_dtype=model_config.head_dtype,
classifier=classifier,
......
......@@ -3,7 +3,7 @@
# Supports FP-Quant compression, see https://arxiv.org/abs/2509.23202
from typing import Any
from typing import Any, Literal, cast
import torch
from torch.nn.parameter import Parameter
......@@ -251,7 +251,11 @@ class FPQuantLinearMethod(LinearMethodBase):
def fused_quantize_mx(
x_flat: torch.Tensor, hadamard_matrix: torch.Tensor, forward_method: str
) -> tuple[torch.Tensor, torch.Tensor]:
return fusedQuantizeMx(x_flat, hadamard_matrix, method=forward_method)
return fusedQuantizeMx(
x_flat,
hadamard_matrix,
method=cast(Literal["quest", "abs_max"], forward_method),
)
def fused_quantize_mx_fake(x_flat, hadamard_matrix, forward_method):
......
......@@ -114,7 +114,7 @@ class QuarkConfig(QuantizationConfig):
:param hf_to_vllm_mapper: maps from hf model structure (the assumed
structure of the qconfig) to vllm model structure
"""
quant_config_with_hf_to_vllm_mapper = {}
quant_config_with_hf_to_vllm_mapper: dict[str, Any] = {}
for k, v in self.quant_config.items():
if isinstance(v, list):
......
......@@ -26,7 +26,7 @@ from vllm.v1.worker.workspace import current_workspace_manager
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._xpu_ops import xpu_ops as ops
from vllm._xpu_ops import xpu_ops
logger = init_logger(__name__)
......@@ -84,12 +84,12 @@ def sparse_attn_indexer(
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
attn_metadata_narrowed = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata_narrowed, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata_narrowed.slot_mapping
has_decode = attn_metadata_narrowed.num_decodes > 0
has_prefill = attn_metadata_narrowed.num_prefills > 0
num_decode_tokens = attn_metadata_narrowed.num_decode_tokens
# During speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens. Truncate k to avoid
......@@ -97,6 +97,8 @@ def sparse_attn_indexer(
num_tokens = slot_mapping.shape[0]
k = k[:num_tokens]
# scale_fmt can be None, but the function expects str
assert scale_fmt is not None
ops.indexer_k_quant_and_cache(
k,
kv_cache,
......@@ -107,7 +109,7 @@ def sparse_attn_indexer(
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
prefill_metadata = attn_metadata_narrowed.prefill
assert prefill_metadata is not None
# Get the full shared workspace buffers once (will allocate on first use)
......@@ -144,7 +146,7 @@ def sparse_attn_indexer(
]
if current_platform.is_xpu():
ops.top_k_per_row_prefill(
xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined]
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
......@@ -167,7 +169,7 @@ def sparse_attn_indexer(
)
if has_decode:
decode_metadata = attn_metadata.decode
decode_metadata = attn_metadata_narrowed.decode
assert decode_metadata is not None
# kv_cache shape [
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
......@@ -217,11 +219,11 @@ def sparse_attn_indexer(
topk_indices,
topk_workspace,
topk_tokens,
attn_metadata.max_seq_len,
attn_metadata_narrowed.max_seq_len,
)
else:
if current_platform.is_xpu():
ops.top_k_per_row_decode(
xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined]
logits,
next_n,
seq_lens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment