"vllm/vscode:/vscode.git/clone" did not exist on "b53d79983c273b2775456d99c0e0890aea073512"
test_deepseekr1_reasoning_parser.py 6.01 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
import pytest
from transformers import AutoTokenizer

from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager

# Registry name used with ReasoningParserManager.get_reasoning_parser().
parser_name = "deepseek_r1"
# Delimiters that wrap the reasoning section in the test outputs below.
start_token = "<think>"
end_token = "</think>"

# Hugging Face model whose tokenizer is used to split the test outputs.
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
    """Return the Hugging Face tokenizer for ``REASONING_MODEL_NAME``.

    The fixture is module-scoped, so the tokenizer is loaded once and shared
    by every test in this file.
    """
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    return tokenizer


# Expected parser results. Every case dict has the same keys:
#   "output":            model text that is tokenized and fed to the parser
#   "reasoning_content": expected reasoning text
#   "content":           expected answer text after "</think>" (None if absent)
#   "is_reasoning_end":  expected parser.is_reasoning_end() on the token ids
# The cases in this group have no opening "<think>" tag.
SIMPLE_REASONING = {
    "output": "This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
# Generation stops right after "</think>", so there is no answer content.
COMPLETE_REASONING = {
    "output": "This is a reasoning section</think>",
    "reasoning_content": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
}
# No "</think>" at all: the whole output is expected to be reasoning.
NO_CONTENT = {
    "output": "This is content",
    "reasoning_content": "This is content",
    "content": None,
    "is_reasoning_end": False,
}
NO_REASONING_STREAMING = {
    "output": "This is a reasoning section",
    "reasoning_content": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": False,
}
MULTIPLE_LINES = {
    "output": "This\nThat</think>This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
}
# Output opens with "</think>", so the reasoning is empty. The non-streaming
# expectation is "" while the streaming one (SHORTEST_REASONING) is None.
SHORTEST_REASONING_NO_STREAMING = {
    "output": "</think>This is the rest",
    "reasoning_content": "",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
SHORTEST_REASONING = {
    "output": "</think>This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
}
# Same expectations as the plain cases, but the output starts with an
# explicit "<think>" tag, which must not appear in the extracted reasoning.
REASONING_WITH_THINK = {
    "output": "<think>This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
COMPLETE_REASONING_WITH_THINK = {
    "output": "<think>This is a reasoning section</think>",
    "reasoning_content": "This is a reasoning section",
    "content": None,
    "is_reasoning_end": True,
}
MULTIPLE_LINES_WITH_THINK = {
    "output": "<think>This\nThat</think>This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
    "is_reasoning_end": True,
}
# NOTE(review): despite the name, the two SHORTEST_*_WITH_THINK outputs below
# contain no "<think>" tag; "<think></think>This is the rest" may have been
# intended. Confirm the expected streaming result before changing the data.
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
    "output": "</think>This is the rest",
    "reasoning_content": "",
    "content": "This is the rest",
    "is_reasoning_end": True,
}
SHORTEST_REASONING_WITH_THINK = {
    "output": "</think>This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
    "is_reasoning_end": True,
}

# Parametrization for test_reasoning: (streaming, expected-results dict).
# Streaming runs use ids ending in "_streaming". The SHORTEST_* cases need a
# separate dict per mode because their expected reasoning differs
# ("" when not streaming, None when streaming).
TEST_CASES = [
    pytest.param(
        False,
        SIMPLE_REASONING,
        id="simple_reasoning",
    ),
    pytest.param(
        True,
        SIMPLE_REASONING,
        id="simple_reasoning_streaming",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING,
        id="complete_reasoning",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING,
        id="complete_reasoning_streaming",
    ),
    pytest.param(
        False,
        NO_CONTENT,
        id="no_content_token",
    ),
    pytest.param(
        True,
        NO_REASONING_STREAMING,
        id="no_reasoning_token_streaming",
    ),
    pytest.param(
        False,
        MULTIPLE_LINES,
        id="multiple_lines",
    ),
    pytest.param(
        True,
        MULTIPLE_LINES,
        id="multiple_lines_streaming",
    ),
    # The ids of these two cases were previously swapped (the streaming run
    # was labelled "shortest" and vice versa).
    pytest.param(
        False,
        SHORTEST_REASONING_NO_STREAMING,
        id="shortest",
    ),
    pytest.param(
        True,
        SHORTEST_REASONING,
        id="shortest_streaming",
    ),
    pytest.param(
        False,
        REASONING_WITH_THINK,
        id="reasoning_with_think",
    ),
    pytest.param(
        True,
        REASONING_WITH_THINK,
        id="reasoning_with_think_streaming",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING_WITH_THINK,
        id="complete_reasoning_with_think",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING_WITH_THINK,
        id="complete_reasoning_with_think_streaming",
    ),
    pytest.param(
        False,
        MULTIPLE_LINES_WITH_THINK,
        id="multiple_lines_with_think",
    ),
    pytest.param(
        True,
        MULTIPLE_LINES_WITH_THINK,
        id="multiple_lines_with_think_streaming",
    ),
    pytest.param(
        False,
        SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
        id="shortest_with_think",
    ),
    pytest.param(
        True,
        SHORTEST_REASONING_WITH_THINK,
        id="shortest_with_think_streaming",
    ),
]


@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    deepseek_r1_qwen_tokenizer,
):
    """Check the DeepSeek-R1 reasoning parser against one expected case.

    The case's ``output`` text is tokenized and passed through
    ``run_reasoning_extraction`` (streaming or not). The extracted reasoning
    and content are compared with the expected values. The test then checks
    ``is_reasoning_end`` and ``extract_content_ids`` on the token ids of the
    same output.
    """
    tokenizer = deepseek_r1_qwen_tokenizer
    output = tokenizer.tokenize(param_dict["output"])
    # Convert each token back to its own text piece, so the extraction
    # helper receives the output one token at a time.
    output_tokens: list[str] = [
        tokenizer.convert_tokens_to_string([token]) for token in output
    ]
    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(tokenizer)

    reasoning, content = run_reasoning_extraction(parser,
                                                  output_tokens,
                                                  streaming=streaming)

    assert reasoning == param_dict["reasoning_content"]
    assert content == param_dict["content"]

    # Test is_reasoning_end
    output_ids = tokenizer.convert_tokens_to_ids(output)
    is_reasoning_end = parser.is_reasoning_end(output_ids)
    assert is_reasoning_end == param_dict["is_reasoning_end"]

    # Test extract_content_ids. It operates on token ids (its result is
    # compared against ids), so both branches must pass output_ids. Passing
    # the token strings would not exercise the id-based lookup.
    content_ids = parser.extract_content_ids(output_ids)
    if param_dict["content"] is not None:
        expected_ids = tokenizer.convert_tokens_to_ids(
            tokenizer.tokenize(param_dict["content"]))
    else:
        expected_ids = []
    assert content_ids == expected_ids