test_template.py 15.9 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import TYPE_CHECKING, Optional

import pytest
from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.template import parse_template
from llamafactory.hparams import DataArguments
chenych's avatar
chenych committed
24
25
26
27
28
29


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer


luopl's avatar
luopl committed
30
HF_TOKEN = os.getenv("HF_TOKEN")
chenych's avatar
chenych committed
31

chenych's avatar
chenych committed
32
33
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4")
chenych's avatar
chenych committed
34
35
36
37
38
39
40
41

MESSAGES = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "很高兴认识你!"},
]

chenych's avatar
chenych committed
42
43
44
45
46
47
48
MESSAGES_WITH_THOUGHT = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "<think>\nModel thought here\n</think>\n\nI am fine!"},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "<think>\n模型思考内容\n</think>\n\n很高兴认识你!"},
]

chenych's avatar
chenych committed
49
50

def _check_tokenization(
    tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
) -> None:
    r"""Check that token ids and texts round-trip consistently through the tokenizer.

    For every aligned (input_ids, text) pair it verifies:
        encode(text) == input_ids
        decode(input_ids) == text

    Args:
        tokenizer: the tokenizer used to encode and decode.
        batch_input_ids: a batch of expected token id sequences.
        batch_text: a batch of expected texts, aligned with ``batch_input_ids``.
    """
    # Fail loudly on mismatched batch sizes instead of letting zip silently
    # truncate the longer side and skip checks.
    assert len(batch_input_ids) == len(batch_text)
    for input_ids, text in zip(batch_input_ids, batch_text):
        assert tokenizer.encode(text, add_special_tokens=False) == input_ids
        assert tokenizer.decode(input_ids) == text


chenych's avatar
chenych committed
63
64
65
66
67
68
69
70
def _check_template(
    model_id: str,
    template_name: str,
    prompt_str: str,
    answer_str: str,
    use_fast: bool,
    messages: list[dict[str, str]] = MESSAGES,
) -> None:
    r"""Verify that a llamafactory template agrees with the tokenizer's own chat template.

    Args:
        model_id: the model id on hugging face hub.
        template_name: the template name.
        prompt_str: the string corresponding to the prompt part.
        answer_str: the string corresponding to the answer part.
        use_fast: whether to use fast tokenizer.
        messages: the list of messages.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
    # Reference rendering produced by the model's built-in chat template.
    expected_text = tokenizer.apply_chat_template(messages, tokenize=False)
    expected_ids = tokenizer.apply_chat_template(messages, tokenize=True)
    assert prompt_str + answer_str == expected_text
    assert prompt_ids + answer_ids == expected_ids
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
chenych's avatar
chenych committed
90
91
92
93


@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
    r"""One-turn encoding should split a llama3 conversation into prompt and final answer."""
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
    # All turns except the last assistant reply belong to the prompt.
    expected_prompt = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    expected_answer = "很高兴认识你!<|eot_id|>"
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (expected_prompt, expected_answer))


@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
    r"""Multi-turn encoding should yield one (prompt, answer) pair per assistant turn."""
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
    # Expected strings in (prompt_1, answer_1, prompt_2, answer_2) order;
    # only the very first prompt carries the <|begin_of_text|> prefix.
    expected_texts = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "I am fine!<|eot_id|>",
        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
        "很高兴认识你!<|eot_id|>",
    )
    flattened_ids = [ids for pair in encoded_pairs for ids in pair]
    _check_tokenization(tokenizer, flattened_ids, expected_texts)


chenych's avatar
chenych committed
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: Optional[bool]):
    r"""Check <think> block placement for one-turn encoding with the qwen3 reasoning template.

    ``enable_thinking`` is parametrized with ``None`` as well, hence the ``Optional`` annotation.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)

    prompt_str = (
        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
        f"{MESSAGES[1]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
    )
    if not cot_messages or enable_thinking is False:
        answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
        if enable_thinking:
            # Thinking enabled: the empty think block is part of the answer.
            answer_str = "<think>\n\n</think>\n\n" + answer_str
        else:
            # Thinking disabled or unset: the empty think block is part of the prompt.
            prompt_str = prompt_str + "<think>\n\n</think>\n\n"
    else:
        # CoT messages with thinking not disabled keep the thought in the answer.
        answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"

    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))


@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: Optional[bool]):
    r"""Check <think> block placement for every turn of the qwen3 reasoning template.

    ``enable_thinking`` is parametrized with ``None`` as well, hence the ``Optional`` annotation.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)

    # Thoughts remain in the answers only when CoT messages are used and thinking is not disabled.
    messages = MESSAGES if not cot_messages or enable_thinking is False else MESSAGES_WITH_THOUGHT
    prompt_str_1 = f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
    answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
    prompt_str_2 = f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
    answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
    if not cot_messages or enable_thinking is False:
        if enable_thinking:
            # Thinking enabled: the empty think block is part of each answer.
            answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
            answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
        else:
            # Thinking disabled or unset: the empty think block is part of each prompt.
            prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"

    _check_tokenization(
        tokenizer,
        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
    )


chenych's avatar
chenych committed
183
184
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
    r"""The generated jinja template should render exactly like the stock chat template."""
    patched = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    reference = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    template = get_template_and_fix_tokenizer(patched, DataArguments(template="llama3"))
    patched.chat_template = template._get_jinja_template(patched)  # llama3 template no replace
    # Template text differs from the stock one, yet yields identical token ids.
    assert patched.chat_template != reference.chat_template
    assert patched.apply_chat_template(MESSAGES) == reference.apply_chat_template(MESSAGES)


chenych's avatar
chenych committed
193
def test_ollama_modelfile():
    # The exported Modelfile must embed the llama3 chat markup in ollama's
    # Go-template syntax and declare the template's stop tokens.
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
    assert template.get_ollama_modelfile(tokenizer) == (
        "# ollama modelfile auto-generated by llamafactory\n\n"
        "FROM .\n\n"
        'TEMPLATE """<|begin_of_text|>'
        "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"
        '{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Content }}'
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        '{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""\n\n'
        'PARAMETER stop "<|eom_id|>"\n'
        'PARAMETER stop "<|eot_id|>"\n'
        "PARAMETER num_ctx 4096\n"
    )


luopl's avatar
luopl committed
210
def test_get_stop_token_ids():
    r"""The llama3 template should expose <|eom_id|> (128008) and <|eot_id|> (128009) as stop tokens."""
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
    llama3_template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
    stop_token_ids = llama3_template.get_stop_token_ids(tokenizer)
    assert set(stop_token_ids) == {128008, 128009}


chenych's avatar
chenych committed
216
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
    r"""Gemma-3 wraps each turn in <start_of_turn>/<end_of_turn> and names the assistant 'model'."""
    user_1, reply_1, user_2, reply_2 = (m["content"] for m in MESSAGES)
    expected_prompt = (
        f"<bos><start_of_turn>user\n{user_1}<end_of_turn>\n"
        f"<start_of_turn>model\n{reply_1}<end_of_turn>\n"
        f"<start_of_turn>user\n{user_2}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    expected_answer = f"{reply_2}<end_of_turn>\n"
    _check_template("google/gemma-3-4b-it", "gemma", expected_prompt, expected_answer, use_fast)
chenych's avatar
chenych committed
227
228


chenych's avatar
chenych committed
229
230
231
232
233
234
235
236
237
238
239
240
241
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma2_template(use_fast: bool):
    r"""Gemma-2 uses the same turn markup as Gemma-3 but its own template name."""
    user_1, reply_1, user_2, reply_2 = (m["content"] for m in MESSAGES)
    expected_prompt = (
        f"<bos><start_of_turn>user\n{user_1}<end_of_turn>\n"
        f"<start_of_turn>model\n{reply_1}<end_of_turn>\n"
        f"<start_of_turn>user\n{user_2}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    expected_answer = f"{reply_2}<end_of_turn>\n"
    _check_template("google/gemma-2-2b-it", "gemma2", expected_prompt, expected_answer, use_fast)


chenych's avatar
chenych committed
242
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
    r"""Llama-3 wraps each role in header markers and ends every turn with <|eot_id|>."""
    user_1, reply_1, user_2, reply_2 = (m["content"] for m in MESSAGES)
    expected_prompt = (
        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{user_1}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n{reply_1}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{user_2}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    expected_answer = f"{reply_2}<|eot_id|>"
    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", expected_prompt, expected_answer, use_fast)


chenych's avatar
chenych committed
255
256
257
258
259
@pytest.mark.parametrize(
    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
)
def test_llama4_template(use_fast: bool):
    r"""Llama-4 renames the header markers and shortens the end token to <|eot|>."""
    user_1, reply_1, user_2, reply_2 = (m["content"] for m in MESSAGES)
    expected_prompt = (
        f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{user_1}<|eot|>"
        f"<|header_start|>assistant<|header_end|>\n\n{reply_1}<|eot|>"
        f"<|header_start|>user<|header_end|>\n\n{user_2}<|eot|>"
        "<|header_start|>assistant<|header_end|>\n\n"
    )
    expected_answer = f"{reply_2}<|eot|>"
    _check_template(TINY_LLAMA4, "llama4", expected_prompt, expected_answer, use_fast)


luopl's avatar
luopl committed
269
@pytest.mark.parametrize(
    "use_fast",
    [
        pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
        pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
    ],
)
def test_phi4_template(use_fast: bool):
    r"""Phi-4 separates role and content with <|im_sep|> inside <|im_start|>...<|im_end|>."""
    user_1, reply_1, user_2, reply_2 = (m["content"] for m in MESSAGES)
    expected_prompt = (
        f"<|im_start|>user<|im_sep|>{user_1}<|im_end|>"
        f"<|im_start|>assistant<|im_sep|>{reply_1}<|im_end|>"
        f"<|im_start|>user<|im_sep|>{user_2}<|im_end|>"
        "<|im_start|>assistant<|im_sep|>"
    )
    expected_answer = f"{reply_2}<|im_end|>"
    _check_template("microsoft/phi-4", "phi4", expected_prompt, expected_answer, use_fast)
chenych's avatar
chenych committed
285
286


chenych's avatar
chenych committed
287
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen2_5_template(use_fast: bool):
    r"""Qwen2.5 injects its default system prompt before the chatml-style turns."""
    user_1, reply_1, user_2, reply_2 = (m["content"] for m in MESSAGES)
    expected_prompt = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{user_1}<|im_end|>\n"
        f"<|im_start|>assistant\n{reply_1}<|im_end|>\n"
        f"<|im_start|>user\n{user_2}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    expected_answer = f"{reply_2}<|im_end|>\n"
    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", expected_prompt, expected_answer, use_fast)
chenych's avatar
chenych committed
299
300


chenych's avatar
chenych committed
301
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
def test_qwen3_template(use_fast: bool, cot_messages: bool):
    r"""Qwen3 keeps a think block (empty or real) at the head of the final answer."""
    expected_prompt = (
        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    if cot_messages:
        messages = MESSAGES_WITH_THOUGHT
        expected_answer = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
    else:
        # Plain messages get an empty think block prepended to the answer.
        messages = MESSAGES
        expected_answer = f"<think>\n\n</think>\n\n{MESSAGES[3]['content']}<|im_end|>\n"

    _check_template("Qwen/Qwen3-8B", "qwen3", expected_prompt, expected_answer, use_fast, messages=messages)
chenych's avatar
chenych committed
318
319


chenych's avatar
chenych committed
320
321
def test_parse_llama3_template():
    r"""parse_template should recover the llama3 slot layout from the jinja chat template."""
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
    parsed = parse_template(tokenizer)
    expected_user_slot = (
        "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    assert parsed.format_user.slots == [expected_user_slot]
    assert parsed.format_assistant.slots == ["{{content}}<|eot_id|>"]
    assert parsed.format_system.slots == ["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
    assert parsed.format_prefix.slots == ["<|begin_of_text|>"]
    assert parsed.default_system == ""


chenych's avatar
chenych committed
333
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen_template():
    r"""parse_template should detect a plain Template carrying Qwen's default system prompt."""
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
    parsed = parse_template(tokenizer)
    # Qwen2.5 is not a reasoning model, so a plain Template is expected.
    assert parsed.__class__.__name__ == "Template"
    assert parsed.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    assert parsed.format_assistant.slots == ["{{content}}<|im_end|>\n"]
    assert parsed.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
    assert parsed.format_prefix.slots == []
    assert parsed.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
chenych's avatar
chenych committed
343
344


chenych's avatar
chenych committed
345
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen3_template():
    r"""parse_template should detect a ReasoningTemplate for Qwen3 with no default system prompt."""
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
    parsed = parse_template(tokenizer)
    # Qwen3 emits <think> blocks, so a ReasoningTemplate is expected.
    assert parsed.__class__.__name__ == "ReasoningTemplate"
    assert parsed.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    assert parsed.format_assistant.slots == ["{{content}}<|im_end|>\n"]
    assert parsed.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
    assert parsed.format_prefix.slots == []
    assert parsed.default_system == ""