Unverified Commit 81786c87 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix async scheduling + reasoning with struct output (#31332)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent f1531d9f
......@@ -608,7 +608,7 @@ Make the response as short as possible.
@pytest.mark.parametrize(
"model_name, backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501
"model_name, backend, tokenizer_mode, reasoning_parser, speculative_config, async_scheduling", # noqa: E501
[
(
"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
......@@ -616,8 +616,10 @@ Make the response as short as possible.
"auto",
"deepseek_r1",
NGRAM_SPEC_CONFIG,
False,
),
("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None),
("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None, False),
("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None, True),
],
)
def test_structured_output_with_reasoning_matrices(
......@@ -626,6 +628,7 @@ def test_structured_output_with_reasoning_matrices(
reasoning_parser: str,
model_name: str,
speculative_config: dict[str, Any] | None,
async_scheduling: bool,
):
if current_platform.is_tpu() and speculative_config:
pytest.skip("TPU does not support speculative decoding")
......@@ -646,6 +649,7 @@ def test_structured_output_with_reasoning_matrices(
),
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config,
async_scheduling=async_scheduling,
)
tokenizer = llm.get_tokenizer()
reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
......
......@@ -71,6 +71,7 @@ class TestReasoningStructuredOutput:
request.prompt_token_ids = [1, 2, 3, 4, 5]
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
request.num_computed_tokens = 5
request.num_output_placeholders = 0
return request
def test_should_fill_bitmask_with_enable_in_reasoning(
......
......@@ -339,8 +339,9 @@ class StructuredOutputManager:
return True
# Check if reasoning ends in *this* step
delta_from = request.num_computed_tokens - request.num_output_placeholders
if self.reasoner.is_reasoning_end_streaming(
request.all_token_ids, request.all_token_ids[request.num_computed_tokens :]
request.all_token_ids, request.all_token_ids[delta_from:]
):
# Reasoning just ended, so we shouldn't advance til
# next pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment