# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import AutoTokenizer

from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager

# Tokenizer checkpoint passed to AutoTokenizer.from_pretrained by the
# module-scoped tokenizer fixture.
REASONING_MODEL_NAME = "baidu/ERNIE-4.5-21B-A3B-Thinking"

# Key used to look up the reasoning parser class in ReasoningParserManager.
parser_name = "ernie45"


@pytest.fixture(scope="module")
def ernie45_tokenizer():
    """Load the tokenizer for ``REASONING_MODEL_NAME`` once per test module."""
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    return tokenizer


# Non-streaming case: text before </think> is expected as reasoning, text
# after it as content.
WITH_THINK = {
    "output": "abc</think>def",
    "reasoning": "abc",
    "content": "def",
}
# Streaming case with the same input and the same expected split.
WITH_THINK_STREAM = {
    "output": "abc</think>def",
    "reasoning": "abc",
    "content": "def",
}
# Without </think>, the whole output is expected as reasoning (non-streaming).
WITHOUT_THINK = {
    "output": "abc",
    "reasoning": "abc",
    "content": None,
}
# Without </think>, the whole output is expected as reasoning (streaming).
WITHOUT_THINK_STREAM = {
    "output": "abc",
    "reasoning": "abc",
    "content": None,
}

# Output that ends exactly at </think>: reasoning only, content is None.
COMPLETE_REASONING = {
    "output": "abc</think>",
    "reasoning": "abc",
    "content": None,
}
# Newlines on either side of </think> are expected to be kept in each part.
MULTILINE_REASONING = {
    "output": "abc\nABC</think>def\nDEF",
    "reasoning": "abc\nABC",
    "content": "def\nDEF",
}

# Parametrization for test_reasoning. Each spec is
# (streaming, expected-case dict, pytest id); the order of the specs is the
# order in which the cases run.
TEST_CASES = [
    pytest.param(is_streaming, case_dict, id=case_id)
    for is_streaming, case_dict, case_id in (
        (False, WITH_THINK, "with_think"),
        (True, WITH_THINK_STREAM, "with_think_stream"),
        (False, WITHOUT_THINK, "without_think"),
        (True, WITHOUT_THINK_STREAM, "without_think_stream"),
        (False, COMPLETE_REASONING, "complete_reasoning"),
        (True, COMPLETE_REASONING, "complete_reasoning_stream"),
        (False, MULTILINE_REASONING, "multiline_reasoning"),
        (True, MULTILINE_REASONING, "multiline_reasoning_stream"),
    )
]


@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    ernie45_tokenizer,
):
    """Check the ernie45 parser's reasoning/content split for one case.

    ``param_dict["output"]`` is tokenized with the ERNIE-4.5 tokenizer. Each
    token is decoded back to text on its own, and tokens that decode to an
    empty string are skipped. The resulting pieces are fed to the parser via
    ``run_reasoning_extraction``.

    Args:
        streaming: Passed through to ``run_reasoning_extraction`` to select
            streaming or non-streaming extraction.
        param_dict: Case dict with ``output``, expected ``reasoning`` and
            expected ``content`` keys.
        ernie45_tokenizer: Module-scoped tokenizer fixture.
    """
    # Decode one token at a time and drop pieces that decode to "", so the
    # parser only ever receives non-empty text pieces.
    output_tokens: list[str] = [
        piece
        for token in ernie45_tokenizer.tokenize(param_dict["output"])
        if (piece := ernie45_tokenizer.convert_tokens_to_string([token]))
    ]

    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
        ernie45_tokenizer
    )

    reasoning, content = run_reasoning_extraction(
        parser, output_tokens, streaming=streaming
    )

    assert reasoning == param_dict["reasoning"]
    assert content == param_dict["content"]