"examples/pooling/token_classify/ner_offline.py" did not exist on "5f696c33b1fbf33fe91ecdd958874b9dd52f79b4"
Unverified Commit f08919b7 authored by Elaine Zhao's avatar Elaine Zhao Committed by GitHub
Browse files

[Bugfix] Respect min_tokens in scheduler stop check (#26317)


Signed-off-by: default avatarElaine Zhao <elaineyz@amazon.com>
parent 93f2c0aa
......@@ -497,6 +497,96 @@ def test_stop_via_update_from_output():
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
def test_check_stop_min_tokens():
"""Test that requests don't stop when min_tokens requirement isn't met."""
from vllm.v1.core.sched.utils import check_stop
# Test case 1: num_output_tokens < min_tokens
# Should return False (don't stop)
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=20,
min_tokens=5,
)
request = Request(
request_id="0",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
# Simulate having generated 3 output tokens (less than min_tokens=5)
request.append_output_token_ids([10, 11, EOS_TOKEN_ID]) # EOS token present
result = check_stop(request, max_model_len=100)
assert result is False, "Should not stop when num_output_tokens<min_tokens"
# Test case 2: num_output_tokens >= min_tokens
# Should follow normal stopping logic (stop on EOS)
request.append_output_token_ids(
[
10,
11,
12,
13,
14,
EOS_TOKEN_ID,
]
) # 6 tokens > min_tokens
result = check_stop(request, max_model_len=100)
assert result is True, "Should stop on EOS when min_tokens met"
assert request.status == RequestStatus.FINISHED_STOPPED
# Test case 3: min_tokens = 0, should follow normal stopping logic
sampling_params_no_min = SamplingParams(
ignore_eos=False,
max_tokens=20,
min_tokens=0,
)
request_no_min = Request(
request_id="1",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_no_min,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
result = check_stop(request_no_min, max_model_len=100)
assert result is True, "Should stop on EOS when min_tokens=0"
assert request_no_min.status == RequestStatus.FINISHED_STOPPED
# Test case 4: min_tokens > 0 with stop token (not EOS)
sampling_params_stop = SamplingParams(
ignore_eos=False,
max_tokens=20,
min_tokens=5,
stop_token_ids=[42],
)
request_stop = Request(
request_id="2",
prompt_token_ids=[0, 1, 2],
sampling_params=sampling_params_stop,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
)
# Only 3 output tokens, less than min_tokens=5, but has stop token
request_stop.append_output_token_ids([10, 11, 42])
result = check_stop(request_stop, max_model_len=100)
assert result is False, "Should not stop when num_output_tokens<min_tokens"
# Test case 5: min_tokens met, should stop on stop token
request_stop.append_output_token_ids(
[10, 11, 12, 13, 14, 42]
) # 6 tokens >= min_tokens=5
result = check_stop(request_stop, max_model_len=100)
assert result is True, "Should stop on stop token when min_tokens met"
assert request_stop.status == RequestStatus.FINISHED_STOPPED
assert request_stop.stop_reason == 42
@pytest.mark.parametrize(
"enable_prefix_caching, prompt_logprobs",
[
......
......@@ -58,6 +58,11 @@ def check_stop(
sampling_params = request.sampling_params
assert sampling_params is not None
min_tokens = sampling_params.min_tokens
if request.num_output_tokens < min_tokens:
return False
last_token_id = request.output_token_ids[-1]
if not sampling_params.ignore_eos and last_token_id == request.eos_token_id:
request.status = RequestStatus.FINISHED_STOPPED
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment