Unverified commit 7ef58737, authored by Wentao Ye and committed by GitHub
Browse files

[CI] Fix mypy for `vllm/v1/structured_output` (#32722)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 5e4e0e51
...@@ -49,6 +49,7 @@ FILES = [ ...@@ -49,6 +49,7 @@ FILES = [
"vllm/v1/metrics", "vllm/v1/metrics",
"vllm/v1/pool", "vllm/v1/pool",
"vllm/v1/sample", "vllm/v1/sample",
"vllm/v1/structured_output",
"vllm/v1/worker", "vllm/v1/worker",
] ]
...@@ -64,7 +65,6 @@ SEPARATE_GROUPS = [ ...@@ -64,7 +65,6 @@ SEPARATE_GROUPS = [
# v1 related # v1 related
"vllm/v1/kv_offload", "vllm/v1/kv_offload",
"vllm/v1/spec_decode", "vllm/v1/spec_decode",
"vllm/v1/structured_output",
] ]
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.
......
...@@ -51,7 +51,7 @@ class ReasoningParser: ...@@ -51,7 +51,7 @@ class ReasoningParser:
return self.model_tokenizer.get_vocab() return self.model_tokenizer.get_vocab()
@abstractmethod @abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
""" """
Check if the reasoning content ends in the input_ids. Check if the reasoning content ends in the input_ids.
...@@ -68,7 +68,7 @@ class ReasoningParser: ...@@ -68,7 +68,7 @@ class ReasoningParser:
""" """
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int] self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool: ) -> bool:
""" """
Check if the reasoning content ends in the input_ids on a Check if the reasoning content ends in the input_ids on a
......
...@@ -65,7 +65,7 @@ class BaseThinkingReasoningParser(ReasoningParser): ...@@ -65,7 +65,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
"think start/end tokens in the tokenizer!" "think start/end tokens in the tokenizer!"
) )
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
start_token_id = self.start_token_id start_token_id = self.start_token_id
end_token_id = self.end_token_id end_token_id = self.end_token_id
...@@ -77,7 +77,7 @@ class BaseThinkingReasoningParser(ReasoningParser): ...@@ -77,7 +77,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
return False return False
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int] self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool: ) -> bool:
end_token_id = self.end_token_id end_token_id = self.end_token_id
return end_token_id in delta_ids return end_token_id in delta_ids
......
...@@ -41,7 +41,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser): ...@@ -41,7 +41,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
return self._parser.is_reasoning_end(input_ids) return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int] self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool: ) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
......
...@@ -78,7 +78,7 @@ class GptOssReasoningParser(ReasoningParser): ...@@ -78,7 +78,7 @@ class GptOssReasoningParser(ReasoningParser):
self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>") self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>")
self.reasoning_max_num_between_tokens = 20 self.reasoning_max_num_between_tokens = 20
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
end_token_ids_prefix = self.reasoning_end_token_ids_prefix end_token_ids_prefix = self.reasoning_end_token_ids_prefix
end_token_ids_suffix = self.reasoning_end_token_ids_suffix end_token_ids_suffix = self.reasoning_end_token_ids_suffix
assert len(end_token_ids_prefix) > 0, "reasoning_end_token_ids_prefix is empty" assert len(end_token_ids_prefix) > 0, "reasoning_end_token_ids_prefix is empty"
......
...@@ -61,7 +61,7 @@ class Holo2ReasoningParser(ReasoningParser): ...@@ -61,7 +61,7 @@ class Holo2ReasoningParser(ReasoningParser):
return self._parser.is_reasoning_end(input_ids) return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int] self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool: ) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
......
...@@ -79,7 +79,7 @@ class HunyuanA13BReasoningParser(ReasoningParser): ...@@ -79,7 +79,7 @@ class HunyuanA13BReasoningParser(ReasoningParser):
self.token_buffer = [] self.token_buffer = []
self.text_buffer = "" self.text_buffer = ""
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
return self.current_state == "response" return self.current_state == "response"
def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_content_ids(self, input_ids: list[int]) -> list[int]:
......
...@@ -31,12 +31,12 @@ class IdentityReasoningParser(ReasoningParser): ...@@ -31,12 +31,12 @@ class IdentityReasoningParser(ReasoningParser):
"constructor during construction." "constructor during construction."
) )
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
# Always return True, since we never treat reasoning specially # Always return True, since we never treat reasoning specially
return True return True
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int] self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool: ) -> bool:
return True return True
......
...@@ -88,7 +88,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser): ...@@ -88,7 +88,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
super().__init__(tokenizer, *args, **kwargs) super().__init__(tokenizer, *args, **kwargs)
self.end_token_id = self.vocab.get("</think>") self.end_token_id = self.vocab.get("</think>")
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
end_token_id = self.end_token_id end_token_id = self.end_token_id
return any(input_id == end_token_id for input_id in reversed(input_ids)) return any(input_id == end_token_id for input_id in reversed(input_ids))
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from functools import cached_property from functools import cached_property
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
...@@ -65,7 +66,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser): ...@@ -65,7 +66,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
return SpecialTokens.end_think return SpecialTokens.end_think
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
has_eot_token = False has_eot_token = False
for id in input_ids[::-1]: for id in input_ids[::-1]:
......
...@@ -242,7 +242,7 @@ class Olmo3ReasoningParser(ReasoningParser): ...@@ -242,7 +242,7 @@ class Olmo3ReasoningParser(ReasoningParser):
think_start=self.think_start, think_end=self.think_end think_start=self.think_start, think_end=self.think_end
) )
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
text = self.model_tokenizer.decode(input_ids) text = self.model_tokenizer.decode(input_ids)
return self.think_end in text return self.think_end in text
......
...@@ -100,11 +100,11 @@ class Step3ReasoningParser(ReasoningParser): ...@@ -100,11 +100,11 @@ class Step3ReasoningParser(ReasoningParser):
return reasoning, content return reasoning, content
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
return self.think_end_token_id in input_ids return self.think_end_token_id in input_ids
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int] self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool: ) -> bool:
end_token_id = self.think_end_token_id end_token_id = self.think_end_token_id
return end_token_id in delta_ids return end_token_id in delta_ids
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -17,15 +19,15 @@ elif current_platform.is_xpu(): ...@@ -17,15 +19,15 @@ elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops from vllm._ipex_ops import ipex_ops
reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash
flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func # type: ignore[assignment]
get_scheduler_metadata = ipex_ops.get_scheduler_metadata get_scheduler_metadata = ipex_ops.get_scheduler_metadata # type: ignore[assignment]
elif current_platform.is_rocm(): elif current_platform.is_rocm():
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
except ImportError: except ImportError:
def flash_attn_varlen_func(*args, **kwargs): def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc]
raise ImportError( raise ImportError(
"ROCm platform requires upstream flash-attn " "ROCm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first." "to be installed. Please install flash-attn first."
......
...@@ -49,7 +49,7 @@ class AiterTritonMLAImpl(AiterMLAImpl): ...@@ -49,7 +49,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
def _flash_attn_varlen_diff_headdims( def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
): ):
result = self.flash_attn_varlen_func( result = self.flash_attn_varlen_func( # type: ignore[call-arg]
q, q,
k, k,
v, v,
......
...@@ -230,7 +230,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): ...@@ -230,7 +230,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def _flash_attn_varlen_diff_headdims( def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
): ):
output = self.flash_attn_varlen_func( output = self.flash_attn_varlen_func( # type: ignore[call-arg]
q=q, q=q,
k=k, k=k,
v=v, v=v,
......
...@@ -294,7 +294,7 @@ class StructuredOutputManager: ...@@ -294,7 +294,7 @@ class StructuredOutputManager:
assert request.structured_output_request is not None assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None: if request.structured_output_request.reasoning_ended is None:
request.structured_output_request.reasoning_ended = ( request.structured_output_request.reasoning_ended = (
self.reasoner.is_reasoning_end(request.prompt_token_ids) self.reasoner.is_reasoning_end(request.prompt_token_ids or [])
) )
return request.structured_output_request.reasoning_ended return request.structured_output_request.reasoning_ended
return True return True
...@@ -323,8 +323,9 @@ class StructuredOutputManager: ...@@ -323,8 +323,9 @@ class StructuredOutputManager:
# Check if reasoning ends in *this* step # Check if reasoning ends in *this* step
delta_from = request.num_computed_tokens - request.num_output_placeholders delta_from = request.num_computed_tokens - request.num_output_placeholders
all_token_ids = request.all_token_ids
if self.reasoner.is_reasoning_end_streaming( if self.reasoner.is_reasoning_end_streaming(
request.all_token_ids, request.all_token_ids[delta_from:] all_token_ids, all_token_ids[delta_from:]
): ):
# Reasoning just ended, so we shouldn't advance til # Reasoning just ended, so we shouldn't advance til
# next pass # next pass
......
...@@ -284,6 +284,9 @@ def serialize_guidance_grammar( ...@@ -284,6 +284,9 @@ def serialize_guidance_grammar(
def validate_guidance_grammar( def validate_guidance_grammar(
sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None: ) -> None:
# if structured output is not enabled, there is nothing to validate
if sampling_params.structured_outputs is None:
return
tp, grm = get_structured_output_key(sampling_params.structured_outputs) tp, grm = get_structured_output_key(sampling_params.structured_outputs)
guidance_grm = serialize_guidance_grammar(tp, grm) guidance_grm = serialize_guidance_grammar(tp, grm)
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer) err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
......
...@@ -69,7 +69,7 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -69,7 +69,7 @@ class XgrammarBackend(StructuredOutputBackend):
if idx < vocab_size: if idx < vocab_size:
encoded_vocab[idx] = token encoded_vocab[idx] = token
stop_token_ids = [self.tokenizer.eos_token_id] stop_token_ids = [self.tokenizer.eos_token_id]
backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str() backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str() # type: ignore[attr-defined]
metadata = xgr.TokenizerInfo._detect_metadata_from_hf(backend_str) metadata = xgr.TokenizerInfo._detect_metadata_from_hf(backend_str)
tokenizer_info = xgr.TokenizerInfo( tokenizer_info = xgr.TokenizerInfo(
encoded_vocab=encoded_vocab, encoded_vocab=encoded_vocab,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment