# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import AutoTokenizer

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus

# Model identifier handed to AutoTokenizer.from_pretrained by the tokenizer
# fixture in this module.
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


class MockReasoningParser(ReasoningParser):
    """Reasoning parser stub driven by a single boolean flag.

    The ``reasoning_active`` attribute decides what ``is_reasoning_end``
    reports; no token inspection takes place.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        reasoning_active: bool = False,
    ):
        """Forward ``tokenizer`` to the base class and store the flag."""
        super().__init__(tokenizer)
        # Plain attribute so callers can flip it after construction.
        self.reasoning_active = reasoning_active

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True exactly when ``reasoning_active`` is falsy.

        ``input_ids`` is accepted for interface compatibility and ignored.
        """
        if self.reasoning_active:
            return False
        return True

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return ``input_ids`` unchanged (the very same list object)."""
        return input_ids


class MockSequence(Sequence):
    """Lightweight ``Sequence`` stand-in backed by a plain token-id list.

    The base-class initializer is not invoked; only the attributes assigned
    below and the accessors overridden here are provided by this mock.
    """

    def __init__(self, token_ids, output_text="test_output", eos_token_id=0):
        """Store the given values and start in the RUNNING state."""
        self.token_ids = token_ids
        self.output_text = output_text
        self.eos_token_id = eos_token_id
        # Every mock starts out running with no stop reason recorded.
        self.status = SequenceStatus.RUNNING
        self.stop_reason = None

    def get_token_ids(self):
        """Return the stored token-id list itself (no copy)."""
        return self.token_ids

    def get_last_token_id(self):
        """Return the final token id, or None when there are no tokens."""
        if not self.token_ids:
            return None
        return self.token_ids[-1]

    def get_len(self):
        """Return the total number of stored token ids."""
        return len(self.token_ids)

    def get_output_len(self):
        """Return the token count minus one.

        The first token is treated as the prompt, the rest as outputs.
        """
        total_tokens = len(self.token_ids)
        return total_tokens - 1  # Simulating prompt + outputs


@pytest.fixture
def deepseek_r1_qwen_tokenizer():
    """Load and return the tokenizer identified by REASONING_MODEL_NAME."""
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    return tokenizer


@pytest.fixture
def stop_checker():
    """Return a StopChecker limited to a maximum model length of 10.

    No reasoner is supplied to this instance.
    """
    return StopChecker(max_model_len=10)

@pytest.fixture
def stop_checker_with_reasoner(deepseek_r1_qwen_tokenizer):
    """Return a StopChecker (max_model_len=10) with a mock reasoner attached.

    The tokenizer fixture is requested as a parameter so that pytest injects
    the loaded tokenizer; referencing the bare module-level name would pass
    the fixture function object to the parser instead of a tokenizer.

    The attached MockReasoningParser starts with ``reasoning_active=False``;
    tests toggle it through ``stop_checker.reasoner.reasoning_active``.
    """
    reasoner = MockReasoningParser(deepseek_r1_qwen_tokenizer)
    return StopChecker(max_model_len=10, reasoner=reasoner)

def test_eos_token_stopping(stop_checker):
    """A sequence ending in its EOS token is marked FINISHED_STOPPED."""
    # The last token id (0) matches the sequence's eos_token_id.
    sequence = MockSequence(eos_token_id=0, token_ids=[1, 2, 0])
    default_params = SamplingParams()

    stop_checker.maybe_stop_sequence(
        sequence,
        sampling_params=default_params,
        new_char_count=1,
    )

    assert sequence.status == SequenceStatus.FINISHED_STOPPED


def test_ignore_eos(stop_checker):
    """With ignore_eos=True a trailing EOS token leaves the sequence RUNNING."""
    sequence = MockSequence(eos_token_id=0, token_ids=[1, 2, 0])
    params = SamplingParams(ignore_eos=True)

    stop_checker.maybe_stop_sequence(
        sequence,
        sampling_params=params,
        new_char_count=1,
    )

    assert sequence.status == SequenceStatus.RUNNING


def test_min_tokens(stop_checker):
    """min_tokens keeps the sequence RUNNING despite a trailing EOS token."""
    # MockSequence reports an output length of 2 here (token count minus
    # one), which is below the min_tokens value of 3.
    sequence = MockSequence(eos_token_id=0, token_ids=[1, 2, 0])
    params = SamplingParams(min_tokens=3)

    stop_checker.maybe_stop_sequence(
        sequence,
        sampling_params=params,
        new_char_count=1,
    )

    assert sequence.status == SequenceStatus.RUNNING


def test_stop_token_ids(stop_checker):
    """A custom stop token id ends the sequence and becomes its stop_reason."""
    # The last token id (3) is listed in stop_token_ids.
    sequence = MockSequence(eos_token_id=0, token_ids=[1, 2, 3])
    params = SamplingParams(stop_token_ids=[3])

    stop_checker.maybe_stop_sequence(
        sequence,
        sampling_params=params,
        new_char_count=1,
    )

    assert sequence.status == SequenceStatus.FINISHED_STOPPED
    assert sequence.stop_reason == 3


def test_stop_strings(stop_checker):
    """A stop string ends the sequence and is stripped from the output."""
    sequence = MockSequence(
        token_ids=[1, 2, 3],
        eos_token_id=0,
        output_text="test output with STOP",
    )
    params = SamplingParams(stop=["STOP"])

    stop_checker.maybe_stop_sequence(
        sequence,
        sampling_params=params,
        new_char_count=1,
    )

    assert sequence.status == SequenceStatus.FINISHED_STOPPED
    assert sequence.stop_reason == "STOP"
    # By default the matched stop string is not kept in the output text.
    assert "STOP" not in sequence.output_text


def test_include_stop_str_in_output(stop_checker):
    """include_stop_str_in_output=True keeps the stop string in the text."""
    sequence = MockSequence(
        token_ids=[1, 2, 3],
        eos_token_id=0,
        output_text="test output with STOP",
    )
    params = SamplingParams(
        include_stop_str_in_output=True,
        stop=["STOP"],
    )

    stop_checker.maybe_stop_sequence(
        sequence,
        sampling_params=params,
        new_char_count=5,
    )

    assert sequence.status == SequenceStatus.FINISHED_STOPPED
    assert "STOP" in sequence.output_text


def test_max_tokens(stop_checker):
    """Reaching max_tokens marks the sequence FINISHED_LENGTH_CAPPED."""
    # MockSequence reports an output length of 2 for three token ids,
    # which equals the max_tokens value used here.
    sequence = MockSequence(eos_token_id=0, token_ids=[1, 2, 3])
    params = SamplingParams(max_tokens=2)

    stop_checker.maybe_stop_sequence(
        sequence,
        sampling_params=params,
        new_char_count=1,
    )

    assert sequence.status == SequenceStatus.FINISHED_LENGTH_CAPPED


def test_max_model_len(stop_checker):
    """Exceeding max_model_len marks the sequence FINISHED_LENGTH_CAPPED."""
    # Eleven token ids, one more than the fixture's max_model_len of 10.
    eleven_tokens = [*range(11)]
    sequence = MockSequence(token_ids=eleven_tokens, eos_token_id=0)
    default_params = SamplingParams()

    stop_checker.maybe_stop_sequence(
        sequence,
        sampling_params=default_params,
        new_char_count=1,
    )

    assert sequence.status == SequenceStatus.FINISHED_LENGTH_CAPPED


def test_reasoning_skip_stops(stop_checker_with_reasoner):
    """While reasoning is active, stop tokens/strings are ignored; EOS is not."""
    checker = stop_checker_with_reasoner
    # Put the mock reasoner into "still reasoning" mode.
    checker.reasoner.reasoning_active = True

    # Case 1: a matching stop token id does not stop the sequence.
    token_seq = MockSequence(eos_token_id=0, token_ids=[1, 2, 3])
    token_params = SamplingParams(stop_token_ids=[3])
    checker.maybe_stop_sequence(
        token_seq,
        sampling_params=token_params,
        new_char_count=1,
    )
    assert token_seq.status == SequenceStatus.RUNNING

    # Case 2: a matching stop string does not stop the sequence either.
    string_seq = MockSequence(output_text="test STOP", token_ids=[1, 2, 3])
    string_params = SamplingParams(stop=["STOP"])
    checker.maybe_stop_sequence(
        string_seq,
        sampling_params=string_params,
        new_char_count=4,
    )
    assert string_seq.status == SequenceStatus.RUNNING

    # Case 3: the EOS token still finishes the sequence.
    eos_seq = MockSequence(eos_token_id=0, token_ids=[1, 2, 0])
    eos_params = SamplingParams()
    checker.maybe_stop_sequence(
        eos_seq,
        sampling_params=eos_params,
        new_char_count=1,
    )
    assert eos_seq.status == SequenceStatus.FINISHED_STOPPED


def test_reasoning_end_enables_stops(stop_checker_with_reasoner):
    """Once reasoning has ended, stop tokens and stop strings apply again."""
    checker = stop_checker_with_reasoner
    # Put the mock reasoner into "reasoning finished" mode.
    checker.reasoner.reasoning_active = False

    # Case 1: a matching stop token id finishes the sequence.
    token_seq = MockSequence(eos_token_id=0, token_ids=[1, 2, 3])
    token_params = SamplingParams(stop_token_ids=[3])
    checker.maybe_stop_sequence(
        token_seq,
        sampling_params=token_params,
        new_char_count=1,
    )
    assert token_seq.status == SequenceStatus.FINISHED_STOPPED

    # Case 2: a matching stop string finishes the sequence.
    string_seq = MockSequence(output_text="test STOP", token_ids=[1, 2, 3])
    string_params = SamplingParams(stop=["STOP"])
    checker.maybe_stop_sequence(
        string_seq,
        sampling_params=string_params,
        new_char_count=4,
    )
    assert string_seq.status == SequenceStatus.FINISHED_STOPPED