# test_deepseekr1_reasoning_parser.py (7.42 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import pytest
from transformers import AutoTokenizer

7
8
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
9
10
11
12
13

# Name under which the parser class is looked up through
# ReasoningParserManager.get_reasoning_parser() in test_reasoning.
parser_name = "deepseek_r1"
# Literal think-tag markers that appear in the test outputs below.
# NOTE(review): neither constant is referenced in the visible part of this file.
start_token = "<think>"
end_token = "</think>"

14
15
16
17
18
19
20
21
# Model identifier handed to AutoTokenizer.from_pretrained() by the tokenizer fixture.
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
    """Provide the tokenizer for ``REASONING_MODEL_NAME``.

    The fixture is module-scoped, so the tokenizer is created once and shared
    by every test in this module.
    """
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    return tokenizer


22
# Reasoning text closed by "</think>" and followed by content; the output has
# no opening "<think>" tag.
SIMPLE_REASONING = {
    "output": "This is a reasoning section</think>This is the rest",
    "reasoning": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
# The output stops right at "</think>": reasoning is expected, content is
# expected to be None.
COMPLETE_REASONING = {
    "output": "This is a reasoning section</think>",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
}
34
# Output without any think tag: the whole text is expected back as reasoning,
# with no content and the reasoning not marked as ended.
NO_CONTENT = {
    "output": "This is content",
    "reasoning": "This is content",
    "content": None,
    "is_reasoning_end": False,
}
# Tag-less output used with streaming=True in TEST_CASES.
# NOTE(review): despite the name, the expectation is that the text is reported
# as reasoning and that content is None.
NO_REASONING_STREAMING = {
    "output": "This is a reasoning section",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": False,
}
# Both the reasoning part and the content part contain a newline; the output
# has no opening "<think>" tag.
MULTIPLE_LINES = {
    "output": "This\nThat</think>This is the rest\nThat",
    "reasoning": "This\nThat",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
}
# Output that starts directly with "</think>". This variant carries the
# non-streaming expectation: reasoning is the empty string.
SHORTEST_REASONING_NO_STREAMING = {
    "output": "</think>This is the rest",
    "reasoning": "",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
# Output that starts directly with "</think>". This variant carries the
# streaming expectation: reasoning is None rather than the empty string.
SHORTEST_REASONING = {
    "output": "</think>This is the rest",
    "reasoning": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
}
# Reasoning wrapped in an explicit "<think>...</think>" pair, followed by
# content; the opening tag is not part of the expected reasoning.
REASONING_WITH_THINK = {
    "output": "<think>This is a reasoning section</think>This is the rest",
    "reasoning": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
# Explicit "<think>...</think>" pair with nothing after the closing tag:
# content is expected to be None.
COMPLETE_REASONING_WITH_THINK = {
    "output": "<think>This is a reasoning section</think>",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
}
# Multi-line reasoning inside an explicit "<think>...</think>" pair, followed
# by multi-line content.
MULTIPLE_LINES_WITH_THINK = {
    "output": "<think>This\nThat</think>This is the rest\nThat",
    "reasoning": "This\nThat",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
}
# Non-streaming expectation for an output that starts with "</think>".
# NOTE(review): despite the WITH_THINK suffix the output has no "<think>" tag;
# the values equal those of SHORTEST_REASONING_NO_STREAMING.
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
    "output": "</think>This is the rest",
    "reasoning": "",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
# Streaming expectation for an output that starts with "</think>".
# NOTE(review): despite the WITH_THINK suffix the output has no "<think>" tag;
# the values equal those of SHORTEST_REASONING.
SHORTEST_REASONING_WITH_THINK = {
    "output": "</think>This is the rest",
    "reasoning": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
}
94
95
# Opening "<think>" tag that is never closed: everything after the tag is
# expected as reasoning and the reasoning is not marked as ended.
THINK_NO_END = {
    "output": "<think>This is a reasoning section",
    "reasoning": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": False,
}
# Empty output with the non-streaming expectation: reasoning is the empty
# string and content is None.
EMPTY = {
    "output": "",
    "reasoning": "",
    "content": None,
    "is_reasoning_end": False,
}
# Empty output with the streaming expectation: both reasoning and content are
# None.
EMPTY_STREAMING = {
    "output": "",
    "reasoning": None,
    "content": None,
    "is_reasoning_end": False,
}
# Output with a newline before "<think>" and another right after "</think>".
# Non-streaming expectation: the leading newline is not part of the reasoning,
# while the newline after "</think>" stays in the content.
NEW_LINE = {
    "output": "\n<think>This is a reasoning section</think>\nThis is the rest",
    "reasoning": "This is a reasoning section",
    "content": "\nThis is the rest",
    "is_reasoning_end": True,
}
# Streaming cannot handle new lines at the beginning of the output:
# both "<think>...</think>" and a bare "...</think>" must be supported, so
# there is no way to know whether text seen before "<think>" is reasoning
# content or not. The leading newline is therefore expected in the reasoning.
NEW_LINE_STREAMING = {
    "output": "\n<think>This is a reasoning section</think>\nThis is the rest",
    "reasoning": "\nThis is a reasoning section",
    "content": "\nThis is the rest",
    "is_reasoning_end": True,
}
128
129
130
131
132

# Parametrization table for test_reasoning. Each entry is
# (streaming flag, expectation dict); ids ending in "_streaming" mark the
# entries whose streaming flag is True.
TEST_CASES = [
    pytest.param(
        False,
        SIMPLE_REASONING,
        id="simple_reasoning",
    ),
    pytest.param(
        True,
        SIMPLE_REASONING,
        id="simple_reasoning_streaming",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING,
        id="complete_reasoning",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING,
        id="complete_reasoning_streaming",
    ),
    pytest.param(
        False,
        NO_CONTENT,
        id="no_content_token",
    ),
    pytest.param(
        True,
        NO_REASONING_STREAMING,
        id="no_reasoning_token_streaming",
    ),
    pytest.param(
        False,
        MULTIPLE_LINES,
        id="multiple_lines",
    ),
    pytest.param(
        True,
        MULTIPLE_LINES,
        id="multiple_lines_streaming",
    ),
    # The two "shortest" ids used to be attached to the opposite streaming
    # flag; they now follow the same pairing as the "shortest_with_think"
    # entries further down.
    pytest.param(
        False,
        SHORTEST_REASONING_NO_STREAMING,
        id="shortest",
    ),
    pytest.param(
        True,
        SHORTEST_REASONING,
        id="shortest_streaming",
    ),
    pytest.param(
        False,
        REASONING_WITH_THINK,
        id="reasoning_with_think",
    ),
    pytest.param(
        True,
        REASONING_WITH_THINK,
        id="reasoning_with_think_streaming",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING_WITH_THINK,
        id="complete_reasoning_with_think",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING_WITH_THINK,
        id="complete_reasoning_with_think_streaming",
    ),
    pytest.param(
        False,
        MULTIPLE_LINES_WITH_THINK,
        id="multiple_lines_with_think",
    ),
    pytest.param(
        True,
        MULTIPLE_LINES_WITH_THINK,
        id="multiple_lines_with_think_streaming",
    ),
    pytest.param(
        False,
        SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
        id="shortest_with_think",
    ),
    pytest.param(
        True,
        SHORTEST_REASONING_WITH_THINK,
        id="shortest_with_think_streaming",
    ),
    pytest.param(
        False,
        THINK_NO_END,
        id="think_no_end",
    ),
    pytest.param(
        True,
        THINK_NO_END,
        id="think_no_end_streaming",
    ),
    pytest.param(
        False,
        EMPTY,
        id="empty",
    ),
    pytest.param(
        True,
        EMPTY_STREAMING,
        id="empty_streaming",
    ),
    pytest.param(
        False,
        NEW_LINE,
        id="new_line",
    ),
    pytest.param(
        True,
        NEW_LINE_STREAMING,
        id="new_line_streaming",
    ),
]


@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    deepseek_r1_qwen_tokenizer,
):
    """Check the parser against one ``TEST_CASES`` entry.

    Args:
        streaming: Forwarded to ``run_reasoning_extraction`` as ``streaming``.
        param_dict: Case dict holding the raw ``output`` text plus the expected
            ``reasoning``, ``content`` and ``is_reasoning_end`` values.
        deepseek_r1_qwen_tokenizer: Tokenizer fixture used to tokenize the
            texts and to construct the parser.
    """
    output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
    # Convert every token back to text on its own, giving one string per token.
    output_tokens: list[str] = [
        deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token]) for token in output
    ]
    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        deepseek_r1_qwen_tokenizer
    )

    reasoning, content = run_reasoning_extraction(
        parser, output_tokens, streaming=streaming
    )

    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]

    # Test is_reasoning_end
    output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
    is_reasoning_end = parser.is_reasoning_end(output_ids)
    assert is_reasoning_end == param_dict["is_reasoning_end"]

    # Test extract_content
    expected_content = param_dict["content"]
    if expected_content is not None:
        expected_content_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
            deepseek_r1_qwen_tokenizer.tokenize(expected_content)
        )
    else:
        expected_content_ids = []
    # Always hand over token ids. The no-content branch previously passed the
    # token strings in ``output``, so the parser was never checked against the
    # real ids in that case.
    assert parser.extract_content_ids(output_ids) == expected_content_ids