# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import TYPE_CHECKING

import pytest
from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.template import parse_template
from llamafactory.hparams import DataArguments


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer


HF_TOKEN = os.getenv("HF_TOKEN")

TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4")

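# A shared two-turn conversation reused by the template tests below.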
MESSAGES = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "很高兴认识你!"},
]

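# The same conversation with <think> blocks prepended to each assistant reply, mimicking reasoning traces.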
MESSAGES_WITH_THOUGHT = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "<think>\nModel thought here\n</think>\n\nI am fine!"},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "<think>\n模型思考内容\n</think>\n\n很高兴认识你!"},
]


def _check_tokenization(
    tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
) -> None:
    r"""Check token ids and texts.

    encode(text) == token_ids
    decode(token_ids) == text
    """
    for input_ids, text in zip(batch_input_ids, batch_text):
        assert tokenizer.encode(text, add_special_tokens=False) == input_ids
        assert tokenizer.decode(input_ids) == text


def _check_template(
    model_id: str,
    template_name: str,
    prompt_str: str,
    answer_str: str,
    use_fast: bool,
    messages: list[dict[str, str]] = MESSAGES,
) -> None:
    r"""Check template.

    Args:
        model_id: the model ID on the Hugging Face Hub.
        template_name: the name of the template to test.
        prompt_str: the expected string for the prompt part.
        answer_str: the expected string for the answer part.
        use_fast: whether to use the fast tokenizer.
        messages: the list of messages to encode.

    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
    content_str = tokenizer.apply_chat_template(messages, tokenize=False)
    content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
    assert content_str == prompt_str + answer_str
    assert content_ids == prompt_ids + answer_ids
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))


@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
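    # encode_oneturn folds the whole history into a single (prompt, answer) pair:
    # the prompt covers everything up to the final assistant reply, which becomes the answer.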
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
    prompt_str = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    answer_str = "很高兴认识你!<|eot_id|>"
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))


@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
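    # encode_multiturn yields one (prompt, answer) pair per assistant turn;
    # only the first prompt carries the <|begin_of_text|> prefix.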
    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
    prompt_str_1 = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    answer_str_1 = "I am fine!<|eot_id|>"
    prompt_str_2 = (
        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    answer_str_2 = "很高兴认识你!<|eot_id|>"
    _check_tokenization(
        tokenizer,
        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
    )


@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
    input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, input_messages)
    output_messages = MESSAGES if enable_thinking is False else input_messages
    prompt_str = (
        f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
        f"{MESSAGES[1]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
    )
    answer_str = f"{output_messages[3]['content']}<|im_end|>\n"
    if not cot_messages or enable_thinking is False:
        if enable_thinking:
            answer_str = "<think>\n\n</think>\n\n" + answer_str
        else:
            prompt_str = prompt_str + "<think>\n\n</think>\n\n"

    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))


@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
    input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    encoded_pairs = template.encode_multiturn(tokenizer, input_messages)
    output_messages = MESSAGES if enable_thinking is False else input_messages
    prompt_str_1 = f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
    answer_str_1 = f"{output_messages[1]['content']}<|im_end|>\n"
    prompt_str_2 = f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
    answer_str_2 = f"{output_messages[3]['content']}<|im_end|>\n"
    if not cot_messages or enable_thinking is False:
        if enable_thinking:
            answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
            answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
        else:
            prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"

    _check_tokenization(
        tokenizer,
        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
    )


@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
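    # Swapping in the jinja template generated by LlamaFactory should leave the rendered ids unchanged.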
    tokenizer.chat_template = template._get_jinja_template(tokenizer)  # the llama3 template does not replace tokenizer.chat_template
    assert tokenizer.chat_template != ref_tokenizer.chat_template
    assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)


def test_ollama_modelfile():
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
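    # The generated Modelfile should express the llama3 chat format in Ollama's Go template syntax.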
    assert template.get_ollama_modelfile(tokenizer) == (
        "# ollama modelfile auto-generated by llamafactory\n\n"
        "FROM .\n\n"
        'TEMPLATE """<|begin_of_text|>'
        "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"
        '{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Content }}'
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        '{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""\n\n'
        'PARAMETER stop "<|eom_id|>"\n'
        'PARAMETER stop "<|eot_id|>"\n'
        "PARAMETER num_ctx 4096\n"
    )


def test_get_stop_token_ids():
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
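    # <|eom_id|> and <|eot_id|> map to ids 128008 and 128009 in this tokenizer.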
    assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
    prompt_str = (
        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
    prompt_str = (
        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)


@pytest.mark.parametrize(
    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
)
def test_llama4_template(use_fast: bool):
    prompt_str = (
        f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
        f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
        f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
        "<|header_start|>assistant<|header_end|>\n\n"
    )
    answer_str = f"{MESSAGES[3]['content']}<|eot|>"
    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)


@pytest.mark.parametrize(
    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
)
def test_phi4_template(use_fast: bool):
    prompt_str = (
        f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
        f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
        f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
        "<|im_start|>assistant<|im_sep|>"
    )
    answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)


@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen2_5_template(use_fast: bool):
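    # The qwen template injects Qwen's default system prompt when the messages contain none.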
    prompt_str = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)


@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
def test_qwen3_template(use_fast: bool, cot_messages: bool):
    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
    prompt_str = (
        f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n"
        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    answer_str = f"{messages[3]['content']}<|im_end|>\n"
    if not cot_messages:
        answer_str = "<think>\n\n</think>\n\n" + answer_str

    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)


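# parse_template reconstructs a Template object from the tokenizer's jinja chat template.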
def test_parse_llama3_template():
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
    template = parse_template(tokenizer)
    assert template.format_user.slots == [
        "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    ]
    assert template.format_assistant.slots == ["{{content}}<|eot_id|>"]
    assert template.format_system.slots == ["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
    assert template.format_prefix.slots == ["<|begin_of_text|>"]
    assert template.default_system == ""


def test_parse_qwen_template():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
    template = parse_template(tokenizer)
mashun1's avatar
mashun1 committed
315
    assert template.__class__.__name__ == "Template"
    assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
    assert template.format_prefix.slots == []
    assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."


def test_parse_qwen3_template():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
    template = parse_template(tokenizer)
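    # Qwen3's chat template includes <think> handling, so it should be parsed as a ReasoningTemplate.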
    assert template.__class__.__name__ == "ReasoningTemplate"
    assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
    assert template.format_prefix.slots == []
    assert template.default_system == ""