test_template.py 15.3 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
chenych's avatar
chenych committed
16
from typing import TYPE_CHECKING
chenych's avatar
chenych committed
17
18
19
20
21

import pytest
from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
chenych's avatar
chenych committed
22
from llamafactory.data.template import parse_template
luopl's avatar
luopl committed
23
from llamafactory.hparams import DataArguments
chenych's avatar
chenych committed
24
25
26
27
28
29


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer


luopl's avatar
luopl committed
30
HF_TOKEN = os.getenv("HF_TOKEN")
chenych's avatar
chenych committed
31

chenych's avatar
chenych committed
32
33
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4")
chenych's avatar
chenych committed
34
35
36
37
38
39
40
41

MESSAGES = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "很高兴认识你!"},
]

chenych's avatar
chenych committed
42
43
44
45
46
47
48
MESSAGES_WITH_THOUGHT = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "<think>\nModel thought here\n</think>\n\nI am fine!"},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "<think>\n模型思考内容\n</think>\n\n很高兴认识你!"},
]

chenych's avatar
chenych committed
49
50

def _check_tokenization(
chenych's avatar
chenych committed
51
    tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
chenych's avatar
chenych committed
52
) -> None:
chenych's avatar
chenych committed
53
    r"""Check token ids and texts.
luopl's avatar
luopl committed
54
55
56
57

    encode(text) == token_ids
    decode(token_ids) == text
    """
chenych's avatar
chenych committed
58
    for input_ids, text in zip(batch_input_ids, batch_text):
luopl's avatar
luopl committed
59
        assert tokenizer.encode(text, add_special_tokens=False) == input_ids
chenych's avatar
chenych committed
60
61
62
        assert tokenizer.decode(input_ids) == text


chenych's avatar
chenych committed
63
64
65
66
67
68
69
70
def _check_template(
    model_id: str,
    template_name: str,
    prompt_str: str,
    answer_str: str,
    use_fast: bool,
    messages: list[dict[str, str]] = MESSAGES,
) -> None:
chenych's avatar
chenych committed
71
    r"""Check template.
chenych's avatar
chenych committed
72
73
74
75
76
77

    Args:
        model_id: the model id on hugging face hub.
        template_name: the template name.
        prompt_str: the string corresponding to the prompt part.
        answer_str: the string corresponding to the answer part.
luopl's avatar
luopl committed
78
        use_fast: whether to use fast tokenizer.
chenych's avatar
chenych committed
79
        messages: the list of messages.
chenych's avatar
chenych committed
80

chenych's avatar
chenych committed
81
    """
luopl's avatar
luopl committed
82
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
chenych's avatar
chenych committed
83
84
    content_str = tokenizer.apply_chat_template(messages, tokenize=False)
    content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
luopl's avatar
luopl committed
85
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
chenych's avatar
chenych committed
86
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
luopl's avatar
luopl committed
87
88
89
    assert content_str == prompt_str + answer_str
    assert content_ids == prompt_ids + answer_ids
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
chenych's avatar
chenych committed
90
91
92
93


@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
chenych's avatar
chenych committed
94
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
luopl's avatar
luopl committed
95
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
chenych's avatar
chenych committed
96
97
98
99
100
101
102
103
104
105
106
107
108
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
    prompt_str = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    answer_str = "很高兴认识你!<|eot_id|>"
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))


@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
chenych's avatar
chenych committed
109
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
luopl's avatar
luopl committed
110
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
chenych's avatar
chenych committed
111
112
113
114
115
116
117
    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
    prompt_str_1 = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    answer_str_1 = "I am fine!<|eot_id|>"
    prompt_str_2 = (
luopl's avatar
luopl committed
118
        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
chenych's avatar
chenych committed
119
120
121
122
123
124
125
126
127
    )
    answer_str_2 = "很高兴认识你!<|eot_id|>"
    _check_tokenization(
        tokenizer,
        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
    )


chenych's avatar
chenych committed
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)

    prompt_str = (
        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
        f"{MESSAGES[1]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
    )
    if not cot_messages or enable_thinking is False:
        answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
        if enable_thinking:
            answer_str = "<think>\n\n</think>\n\n" + answer_str
        else:
            prompt_str = prompt_str + "<think>\n\n</think>\n\n"
    else:
        answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"

    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))


@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES)

    messages = MESSAGES if not cot_messages or enable_thinking is False else MESSAGES_WITH_THOUGHT
    prompt_str_1 = f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
    answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
    prompt_str_2 = f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
    answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
    if not cot_messages or enable_thinking is False:
        if enable_thinking:
            answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
            answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
        else:
            prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"

    _check_tokenization(
        tokenizer,
        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
    )


chenych's avatar
chenych committed
183
184
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
chenych's avatar
chenych committed
185
186
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
luopl's avatar
luopl committed
187
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
chenych's avatar
chenych committed
188
    tokenizer.chat_template = template._get_jinja_template(tokenizer)  # llama3 template no replace
chenych's avatar
chenych committed
189
190
191
192
    assert tokenizer.chat_template != ref_tokenizer.chat_template
    assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)


chenych's avatar
chenych committed
193
def test_ollama_modelfile():
chenych's avatar
chenych committed
194
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
chenych's avatar
chenych committed
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
    assert template.get_ollama_modelfile(tokenizer) == (
        "# ollama modelfile auto-generated by llamafactory\n\n"
        "FROM .\n\n"
        'TEMPLATE """<|begin_of_text|>'
        "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"
        '{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Content }}'
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        '{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""\n\n'
        'PARAMETER stop "<|eom_id|>"\n'
        'PARAMETER stop "<|eot_id|>"\n'
        "PARAMETER num_ctx 4096\n"
    )


luopl's avatar
luopl committed
210
def test_get_stop_token_ids():
chenych's avatar
chenych committed
211
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
luopl's avatar
luopl committed
212
213
214
215
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
    assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}


chenych's avatar
chenych committed
216
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
luopl's avatar
luopl committed
217
218
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
chenych's avatar
chenych committed
219
    prompt_str = (
chenych's avatar
chenych committed
220
221
222
        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
chenych's avatar
chenych committed
223
224
        "<start_of_turn>model\n"
    )
chenych's avatar
chenych committed
225
    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
chenych's avatar
chenych committed
226
    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
chenych's avatar
chenych committed
227
228
229


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
luopl's avatar
luopl committed
230
231
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
chenych's avatar
chenych committed
232
    prompt_str = (
chenych's avatar
chenych committed
233
234
235
        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
chenych's avatar
chenych committed
236
237
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
chenych's avatar
chenych committed
238
    answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
luopl's avatar
luopl committed
239
240
241
    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)


chenych's avatar
chenych committed
242
243
244
245
246
@pytest.mark.parametrize(
    "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
)
def test_llama4_template(use_fast: bool):
    prompt_str = (
chenych's avatar
chenych committed
247
248
249
        f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
        f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
        f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
chenych's avatar
chenych committed
250
251
        "<|header_start|>assistant<|header_end|>\n\n"
    )
chenych's avatar
chenych committed
252
    answer_str = f"{MESSAGES[3]['content']}<|eot|>"
chenych's avatar
chenych committed
253
254
255
    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)


luopl's avatar
luopl committed
256
@pytest.mark.parametrize(
chenych's avatar
chenych committed
257
258
259
260
261
    "use_fast",
    [
        pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
        pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
    ],
luopl's avatar
luopl committed
262
263
264
)
def test_phi4_template(use_fast: bool):
    prompt_str = (
chenych's avatar
chenych committed
265
266
267
        f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
        f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
        f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
luopl's avatar
luopl committed
268
269
        "<|im_start|>assistant<|im_sep|>"
    )
chenych's avatar
chenych committed
270
    answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
luopl's avatar
luopl committed
271
    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
chenych's avatar
chenych committed
272
273


chenych's avatar
chenych committed
274
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
luopl's avatar
luopl committed
275
@pytest.mark.parametrize("use_fast", [True, False])
chenych's avatar
chenych committed
276
def test_qwen2_5_template(use_fast: bool):
chenych's avatar
chenych committed
277
    prompt_str = (
chenych's avatar
chenych committed
278
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
chenych's avatar
chenych committed
279
280
281
        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
chenych's avatar
chenych committed
282
283
        "<|im_start|>assistant\n"
    )
chenych's avatar
chenych committed
284
    answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
chenych's avatar
chenych committed
285
    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
chenych's avatar
chenych committed
286
287


chenych's avatar
chenych committed
288
@pytest.mark.parametrize("use_fast", [True, False])
chenych's avatar
chenych committed
289
290
@pytest.mark.parametrize("cot_messages", [True, False])
def test_qwen3_template(use_fast: bool, cot_messages: bool):
chenych's avatar
chenych committed
291
    prompt_str = (
chenych's avatar
chenych committed
292
293
294
        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
chenych's avatar
chenych committed
295
296
        "<|im_start|>assistant\n"
    )
chenych's avatar
chenych committed
297
298
299
300
301
302
303
304
    if not cot_messages:
        answer_str = f"<think>\n\n</think>\n\n{MESSAGES[3]['content']}<|im_end|>\n"
        messages = MESSAGES
    else:
        answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
        messages = MESSAGES_WITH_THOUGHT

    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
chenych's avatar
chenych committed
305
306


chenych's avatar
chenych committed
307
308
def test_parse_llama3_template():
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
chenych's avatar
chenych committed
309
310
311
312
313
314
315
316
317
318
319
    template = parse_template(tokenizer)
    assert template.format_user.slots == [
        "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    ]
    assert template.format_assistant.slots == ["{{content}}<|eot_id|>"]
    assert template.format_system.slots == ["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
    assert template.format_prefix.slots == ["<|begin_of_text|>"]
    assert template.default_system == ""


chenych's avatar
chenych committed
320
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
chenych's avatar
chenych committed
321
def test_parse_qwen_template():
chenych's avatar
chenych committed
322
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
chenych's avatar
chenych committed
323
    template = parse_template(tokenizer)
chenych's avatar
chenych committed
324
    assert template.__class__.__name__ == "Template"
chenych's avatar
chenych committed
325
326
327
328
    assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
    assert template.format_prefix.slots == []
chenych's avatar
chenych committed
329
    assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
chenych's avatar
chenych committed
330
331


chenych's avatar
chenych committed
332
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
chenych's avatar
chenych committed
333
334
335
def test_parse_qwen3_template():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
    template = parse_template(tokenizer)
chenych's avatar
chenych committed
336
    assert template.__class__.__name__ == "ReasoningTemplate"
chenych's avatar
chenych committed
337
338
339
340
341
    assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
    assert template.format_prefix.slots == []
    assert template.default_system == ""