"test/test_scatter.py" did not exist on "52f2ad25f85f469a675dc5c7ba20e8790a6e6677"
Commit c7c477c7 authored by chenych's avatar chenych
Browse files

add grpo

parents
Pipeline #2942 failed with stages
in 0 seconds
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.data import Role
from llamafactory.data.converter import get_dataset_converter
from llamafactory.data.parser import DatasetAttr
from llamafactory.hparams import DataArguments
def test_alpaca_converter():
    """Check that the alpaca converter maps one raw record to the aligned format."""
    raw_example = {
        "instruction": "Solve the math problem.",
        "input": "3 + 4",
        "output": "The answer is 7.",
    }
    converter = get_dataset_converter(
        "alpaca",
        DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset"),
        DataArguments(),
    )
    # The expected user turn holds the instruction and the input joined by a newline.
    expected = {
        "_prompt": [{"role": Role.USER.value, "content": "Solve the math problem.\n3 + 4"}],
        "_response": [{"role": Role.ASSISTANT.value, "content": "The answer is 7."}],
        "_system": "",
        "_tools": "",
        "_images": None,
        "_videos": None,
        "_audios": None,
    }
    assert converter(raw_example) == expected
def test_sharegpt_converter():
    """Check that the sharegpt converter splits a conversation into system, prompt and response."""
    raw_example = {
        "conversations": [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": "Solve the math problem.\n3 + 4"},
            {"from": "gpt", "value": "The answer is 7."},
        ]
    }
    converter = get_dataset_converter(
        "sharegpt",
        DatasetAttr("hf_hub", "llamafactory/tiny-supervised-dataset"),
        DataArguments(),
    )
    # The system turn is expected under "_system" rather than inside the prompt list.
    expected = {
        "_prompt": [{"role": Role.USER.value, "content": "Solve the math problem.\n3 + 4"}],
        "_response": [{"role": Role.ASSISTANT.value, "content": "The answer is 7."}],
        "_system": "You are a helpful assistant.",
        "_tools": "",
        "_images": None,
        "_videos": None,
        "_audios": None,
    }
    assert converter(raw_example) == expected
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from datetime import datetime
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
# A single tool call (name plus keyword arguments); the tests serialize it with json.dumps
# and feed it to the function formatters.
FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
# A one-element list of tool definitions; the tests serialize it with json.dumps and feed it
# to the tool formatters. Only "foo" is listed as required.
TOOLS = [
    {
        "name": "test_tool",
        "description": "tool_desc",
        "parameters": {
            "type": "object",
            "properties": {
                "foo": {"type": "string", "description": "foo_desc"},
                "bar": {"type": "number", "description": "bar_desc"},
            },
            "required": ["foo"],
        },
    }
]
def test_empty_formatter():
    """Check that an EmptyFormatter applied without arguments yields its single newline slot."""
    rendered = EmptyFormatter(slots=["\n"]).apply()
    assert rendered == ["\n"]
def test_string_formatter():
    """Check that StringFormatter substitutes the content placeholder inside its slots."""
    rendered = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"]).apply(content="Hi")
    assert rendered == ["<s>", "Human: Hi\nAssistant:"]
def test_function_formatter():
    """Check the default-format rendering of a single serialized tool call."""
    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
    expected = [
        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
        "</s>",
    ]
    assert formatter.apply(content=json.dumps(FUNCTION)) == expected
def test_multi_function_formatter():
    """Check the default-format rendering of two serialized tool calls."""
    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
    serialized_calls = json.dumps([FUNCTION, FUNCTION])
    # Both calls are expected in one slot, separated by a newline.
    expected = [
        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
        "</s>",
    ]
    assert formatter.apply(content=serialized_calls) == expected
def test_default_tool_formatter():
    """Check the system-prompt text produced by the default tool format for TOOLS."""
    expected_prompt = (
        "You have access to the following tools:\n"
        "> Tool Name: test_tool\n"
        "Tool Description: tool_desc\n"
        "Tool Args:\n"
        " - foo (string, required): foo_desc\n"
        " - bar (number): bar_desc\n\n"
        "Use the following format if using a tool:\n"
        "```\n"
        "Action: tool name (one of [test_tool])\n"
        "Action Input: the input to the tool, in a JSON format representing the kwargs "
        """(e.g. ```{"input": "hello world", "num_beams": 5}```)\n"""
        "```\n"
    )
    rendered = ToolFormatter(tool_format="default").apply(content=json.dumps(TOOLS))
    assert rendered == [expected_prompt]
def test_default_tool_extractor():
    """Check that one default-format tool call is extracted as a (name, arguments) pair."""
    model_output = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
    extracted = ToolFormatter(tool_format="default").extract(model_output)
    assert extracted == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_default_multi_tool_extractor():
    """Check that two default-format tool calls are extracted in order."""
    model_output = (
        """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}"""
    )
    expected = [
        ("test_tool", """{"foo": "bar", "size": 10}"""),
        ("another_tool", """{"foo": "job", "size": 2}"""),
    ]
    assert ToolFormatter(tool_format="default").extract(model_output) == expected
def test_glm4_function_formatter():
    """Check the glm4-format rendering of a single serialized tool call."""
    formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
    rendered = formatter.apply(content=json.dumps(FUNCTION))
    assert rendered == ["""tool_name\n{"foo": "bar", "size": 10}"""]
def test_glm4_tool_formatter():
    """Check the system-prompt text produced by the glm4 tool format for TOOLS."""
    # The tool definition is expected as 4-space-indented JSON with non-ASCII characters kept.
    tool_json = json.dumps(TOOLS[0], indent=4, ensure_ascii=False)
    expected_prompt = (
        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的,"
        "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具\n\n"
        f"## test_tool\n\n{tool_json}\n"
        "在调用上述函数时,请使用 Json 格式表示调用的参数。"
    )
    rendered = ToolFormatter(tool_format="glm4").apply(content=json.dumps(TOOLS))
    assert rendered == [expected_prompt]
def test_glm4_tool_extractor():
    """Check that one glm4-format tool call is extracted as a (name, arguments) pair."""
    model_output = """test_tool\n{"foo": "bar", "size": 10}\n"""
    extracted = ToolFormatter(tool_format="glm4").extract(model_output)
    assert extracted == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_llama3_function_formatter():
    """Check the llama3-format rendering of a single serialized tool call."""
    formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
    # The expected output uses "parameters" where the input used "arguments".
    expected = ["""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""]
    assert formatter.apply(content=json.dumps(FUNCTION)) == expected
def test_llama3_multi_function_formatter():
    """Check the llama3-format rendering of two serialized tool calls."""
    formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
    serialized_calls = json.dumps([FUNCTION, FUNCTION])
    # Multiple calls are expected as one JSON array followed by the end-of-turn token.
    expected = [
        """[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
        """<|eot_id|>"""
    ]
    assert formatter.apply(content=serialized_calls) == expected
def test_llama3_tool_formatter():
    """Check the system-prompt text produced by the llama3 tool format for TOOLS."""
    formatter = ToolFormatter(tool_format="llama3")
    # The expected prompt embeds today's date, so it is computed here rather than hard-coded.
    date = datetime.now().strftime("%d %b %Y")
    wrapped_tool = {"type": "function", "function": TOOLS[0]}
    expected_prompt = (
        f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
        "You have access to the following functions. "
        "To call a function, please respond with JSON for a function call. "
        """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
        f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
    )
    assert formatter.apply(content=json.dumps(TOOLS)) == [expected_prompt]
def test_llama3_tool_extractor():
    """Check that one llama3-format tool call is extracted as a (name, arguments) pair."""
    model_output = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
    extracted = ToolFormatter(tool_format="llama3").extract(model_output)
    assert extracted == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_llama3_multi_tool_extractor():
    """Check that a llama3-format JSON array of two tool calls is extracted in order."""
    model_output = (
        """[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
        """{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
    )
    expected = [
        ("test_tool", """{"foo": "bar", "size": 10}"""),
        ("another_tool", """{"foo": "job", "size": 2}"""),
    ]
    assert ToolFormatter(tool_format="llama3").extract(model_output) == expected
def test_mistral_function_formatter():
    """Check the mistral-format rendering of a single serialized tool call."""
    formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
    # Even a single call is expected to be wrapped in a JSON array.
    expected = [
        "[TOOL_CALLS] "
        """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
        "</s>",
    ]
    assert formatter.apply(content=json.dumps(FUNCTION)) == expected
def test_mistral_multi_function_formatter():
    """Check the mistral-format rendering of two serialized tool calls."""
    formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
    serialized_calls = json.dumps([FUNCTION, FUNCTION])
    expected = [
        "[TOOL_CALLS] "
        """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
        """{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
        "</s>",
    ]
    assert formatter.apply(content=serialized_calls) == expected
def test_mistral_tool_formatter():
    """Check the tool-listing text produced by the mistral tool format for TOOLS."""
    formatter = ToolFormatter(tool_format="mistral")
    wrapped_tools = [{"type": "function", "function": TOOLS[0]}]
    expected_prompt = "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
    assert formatter.apply(content=json.dumps(TOOLS)) == [expected_prompt]
def test_mistral_tool_extractor():
    """Check that one mistral-format tool call is extracted as a (name, arguments) pair."""
    model_output = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
    extracted = ToolFormatter(tool_format="mistral").extract(model_output)
    assert extracted == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_mistral_multi_tool_extractor():
    """Check that a mistral-format JSON array of two tool calls is extracted in order."""
    model_output = (
        """[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
        """{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
    )
    expected = [
        ("test_tool", """{"foo": "bar", "size": 10}"""),
        ("another_tool", """{"foo": "job", "size": 2}"""),
    ]
    assert ToolFormatter(tool_format="mistral").extract(model_output) == expected
def test_qwen_function_formatter():
    """Check the qwen-format rendering of a single serialized tool call."""
    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
    expected = [
        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call><|im_end|>\n"""
    ]
    assert formatter.apply(content=json.dumps(FUNCTION)) == expected
def test_qwen_multi_function_formatter():
    """Check the qwen-format rendering of two serialized tool calls."""
    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
    serialized_calls = json.dumps([FUNCTION, FUNCTION])
    # Each call is expected in its own <tool_call> element, separated by a newline.
    expected = [
        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
        "<|im_end|>\n"
    ]
    assert formatter.apply(content=serialized_calls) == expected
def test_qwen_tool_formatter():
    """Check the system-prompt text produced by the qwen tool format for TOOLS."""
    formatter = ToolFormatter(tool_format="qwen")
    wrapped_tool = {"type": "function", "function": TOOLS[0]}
    expected_prompt = (
        "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}"
        "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
        """<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
        """"arguments": <args-json-object>}\n</tool_call>"""
    )
    assert formatter.apply(content=json.dumps(TOOLS)) == [expected_prompt]
def test_qwen_tool_extractor():
    """Check that one qwen-format tool call is extracted as a (name, arguments) pair."""
    model_output = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
    extracted = ToolFormatter(tool_format="qwen").extract(model_output)
    assert extracted == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_qwen_multi_tool_extractor():
    """Check that two qwen-format <tool_call> elements are extracted in order."""
    model_output = (
        """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
        """<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
    )
    expected = [
        ("test_tool", """{"foo": "bar", "size": 10}"""),
        ("another_tool", """{"foo": "job", "size": 2}"""),
    ]
    assert ToolFormatter(tool_format="qwen").extract(model_output) == expected
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from llamafactory.train.test_utils import load_dataset_module
# Dataset and model identifiers; each one can be overridden through the environment
# variable of the same name.
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
# Base keyword arguments passed to load_dataset_module by every test below; individual
# tests add val_size or eval_dataset on top of these.
TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",
    "template": "llama3",
    "dataset": TINY_DATA,
    "dataset_dir": "ONLINE",
    "cutoff_len": 8192,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}
def test_load_train_only():
    """With the base arguments only, a train split is expected and no eval split."""
    loaded = load_dataset_module(**TRAIN_ARGS)
    train_split = loaded.get("train_dataset")
    eval_split = loaded.get("eval_dataset")
    assert train_split is not None
    assert eval_split is None
def test_load_val_size():
    """With val_size=0.1 added, both a train split and an eval split are expected."""
    loaded = load_dataset_module(val_size=0.1, **TRAIN_ARGS)
    for split_name in ("train_dataset", "eval_dataset"):
        assert loaded.get(split_name) is not None
def test_load_eval_data():
    """With an explicit eval_dataset, both a train split and an eval split are expected."""
    loaded = load_dataset_module(eval_dataset=TINY_DATA, **TRAIN_ARGS)
    for split_name in ("train_dataset", "eval_dataset"):
        assert loaded.get(split_name) is not None
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any
import numpy as np
import pytest
import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from llamafactory.data.mm_plugin import BasePlugin
from llamafactory.model.loader import TokenizerModule
# Hugging Face access token; tests for gated models are skipped when it is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Model identifiers; each one can be overridden through the environment variable of the same name.
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4")
# Conversation whose user turn contains one <image> placeholder.
MM_MESSAGES = [
    {"role": "user", "content": "<image>What is in this image?"},
    {"role": "assistant", "content": "A cat."},
]
# Conversation with one <image> placeholder and one <audio> placeholder.
OMNI_MESSAGES = [
    {"role": "user", "content": "<image>What is in this image?"},
    {"role": "assistant", "content": "A cat."},
    {"role": "user", "content": "<audio>What is in this audio?"},
    {"role": "assistant", "content": "Nothing."},
]
# Conversation without any multimodal placeholder.
TEXT_MESSAGES = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
]
# One all-zero array of length 1600 used as the audio input.
AUDIOS = [np.zeros(1600)]
# One 32x32 white RGB image used as the image input.
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
# Empty media lists for the text-only checks.
NO_IMAGES = []
NO_VIDEOS = []
NO_AUDIOS = []
# Per-sample media counts passed to get_mm_inputs alongside the media lists above.
IMGLENS = [1]
AUDLENS = [1]
NO_IMGLENS = [0]
NO_VIDLENS = [0]
NO_AUDLENS = [0]
# Dummy token ids and labels passed to process_token_ids.
INPUT_IDS = [0, 1, 2, 3, 4]
LABELS = [0, 1, 2, 3, 4]
# A single batch row of 1024 ones passed to get_mm_inputs.
BATCH_IDS = [[1] * 1024]
def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
    """Run the processor's image processor on IMAGES and return its PyTorch-tensor output.

    Args:
        processor: object exposing an ``image_processor`` attribute.

    Raises:
        AttributeError: if ``processor`` has no ``image_processor`` attribute.
    """
    process_images: BaseImageProcessor = processor.image_processor
    reference_inputs = process_images(images=IMAGES, return_tensors="pt")
    return reference_inputs
def _get_omni_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
    """Build reference inputs from the processor's image processor and feature extractor.

    The image processor is run on IMAGES and the feature extractor on AUDIOS; both outputs are
    merged into one dict, and the resulting ``attention_mask`` entry is renamed to
    ``feature_attention_mask``.

    Args:
        processor: object exposing ``image_processor`` and ``feature_extractor`` attributes and,
            optionally, ``audio_sampling_rate`` (16000 is used when it is absent).
    """
    process_images: BaseImageProcessor = getattr(processor, "image_processor", None)
    extract_audio = getattr(processor, "feature_extractor", None)

    omni_inputs = {}
    image_features = process_images(IMAGES, return_tensors="pt")
    omni_inputs.update(image_features)

    sampling_rate = getattr(processor, "audio_sampling_rate", 16000)
    audio_features = extract_audio(
        AUDIOS,
        sampling_rate=sampling_rate,
        return_attention_mask=True,
        padding="max_length",
        return_tensors="pt",
    )
    omni_inputs.update(audio_features)

    # Store the mask under the key the expected inputs use.
    omni_inputs["feature_attention_mask"] = omni_inputs.pop("attention_mask")
    return omni_inputs
def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
assert batch_a.keys() == batch_b.keys()
for key in batch_a.keys():
if isinstance(batch_a[key], torch.Tensor):
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
assert len(batch_a[key]) == len(batch_b[key])
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
else:
assert batch_a[key] == batch_b[key]
def _load_tokenizer_module(model_name_or_path: str) -> "TokenizerModule":
    """Load the tokenizer module for a model, using the ``default`` template for argument parsing.

    Args:
        model_name_or_path: model identifier forwarded to ``get_infer_args``.

    Returns:
        Whatever ``load_tokenizer`` returns for the parsed model arguments.
    """
    infer_config = {"model_name_or_path": model_name_or_path, "template": "default"}
    # Only the first element (the model arguments) is needed; the rest is discarded.
    model_args, *_unused = get_infer_args(infer_config)
    return load_tokenizer(model_args)
def _check_plugin(
    plugin: "BasePlugin",
    tokenizer: "PreTrainedTokenizer",
    processor: "ProcessorMixin",
    expected_mm_messages: list[dict[str, str]] = MM_MESSAGES,
    expected_input_ids: list[int] = INPUT_IDS,
    expected_labels: list[int] = LABELS,
    expected_mm_inputs: "dict[str, Any] | None" = None,
    expected_no_mm_inputs: "dict[str, Any] | None" = None,
) -> None:
    """Run a multimodal plugin through the shared checks.

    A plugin whose class is named ``Qwen2OmniPlugin`` is checked with OMNI_MESSAGES, IMAGES and
    AUDIOS; any other plugin whose class is not named ``BasePlugin`` is checked with MM_MESSAGES
    and IMAGES. Every plugin is then checked with TEXT_MESSAGES and no media.

    Args:
        plugin: the plugin under test.
        tokenizer: tokenizer passed to ``process_token_ids``.
        processor: processor passed to every plugin call.
        expected_mm_messages: expected result of ``process_messages`` for the multimodal input.
        expected_input_ids: expected input ids from ``process_token_ids`` for the multimodal input.
        expected_labels: expected labels from ``process_token_ids`` for the multimodal input.
        expected_mm_inputs: expected result of ``get_mm_inputs`` for the multimodal input;
            ``None`` means an empty dict.
        expected_no_mm_inputs: expected result of ``get_mm_inputs`` for the text-only input;
            ``None`` means an empty dict.

    Raises:
        AssertionError: if any plugin output differs from the expectation.
    """
    # None stands in for "empty dict" so that no mutable object is shared between calls.
    if expected_mm_inputs is None:
        expected_mm_inputs = {}
    if expected_no_mm_inputs is None:
        expected_no_mm_inputs = {}

    plugin_name = plugin.__class__.__name__
    if plugin_name == "Qwen2OmniPlugin":  # test omni_messages
        assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
            expected_input_ids,
            expected_labels,
        )
        _is_close(
            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
            expected_mm_inputs,
        )
    elif plugin_name != "BasePlugin":  # test mm_messages
        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
            expected_input_ids,
            expected_labels,
        )
        _is_close(
            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
            expected_mm_inputs,
        )

    # test text_messages
    assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == TEXT_MESSAGES
    assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
        INPUT_IDS,
        LABELS,
    )
    _is_close(
        plugin.get_mm_inputs(
            NO_IMAGES, NO_VIDEOS, NO_AUDIOS, NO_IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor
        ),
        expected_no_mm_inputs,
    )
def test_base_plugin():
    """Run the shared checks on the base plugin with the tiny Llama-3 tokenizer module."""
    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3)
    plugin_kwargs = {"plugin": get_mm_plugin(name="base"), **tokenizer_module}
    _check_plugin(**plugin_kwargs)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_gemma3_plugin():
    """Run the shared checks on the gemma3 plugin with the google/gemma-3-4b-it tokenizer module."""
    image_seqlen = 256
    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
    gemma3_plugin = get_mm_plugin(name="gemma3", image_token="<image_soft_token>")
    image_tokens_expanded = "<image_soft_token>" * image_seqlen
    # Text expected in place of each <image> placeholder.
    image_block = f"\n\n<start_of_image>{image_tokens_expanded}<end_of_image>\n\n"
    expected_messages = [
        {field: text.replace("<image>", image_block) for field, text in message.items()}
        for message in MM_MESSAGES
    ]
    expected_mm_inputs = _get_mm_inputs(tokenizer_module["processor"])
    # "num_crops" is not part of the expected plugin output, while token type ids are.
    expected_mm_inputs.pop("num_crops")
    expected_mm_inputs["token_type_ids"] = [[0] * 1024]
    plugin_kwargs = {
        "plugin": gemma3_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": expected_mm_inputs,
        "expected_no_mm_inputs": {"token_type_ids": [[0] * 1024]},
    }
    _check_plugin(**plugin_kwargs)
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_internvl_plugin():
    """Run the shared checks on the intern_vl plugin with the OpenGVLab/InternVL3-1B-hf tokenizer module."""
    image_seqlen = 256
    tokenizer_module = _load_tokenizer_module(model_name_or_path="OpenGVLab/InternVL3-1B-hf")
    internvl_plugin = get_mm_plugin("intern_vl", image_token="<image>", video_token="<video>")
    # Text expected in place of each <image> placeholder.
    image_block = f"<img>{'<IMG_CONTEXT>' * image_seqlen * 1}</img>"
    expected_messages = [
        {field: text.replace("<image>", image_block) for field, text in message.items()}
        for message in MM_MESSAGES
    ]
    expected_mm_inputs = _get_mm_inputs(tokenizer_module["processor"])
    # "num_patches" is dropped from the reference inputs when the processor returns it.
    expected_mm_inputs.pop("num_patches", None)
    plugin_kwargs = {
        "plugin": internvl_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": expected_mm_inputs,
    }
    _check_plugin(**plugin_kwargs)
@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
def test_llama4_plugin():
    """Run the shared checks on the llama4 plugin with the tiny Llama-4 tokenizer module."""
    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
    processor = tokenizer_module["processor"]
    llama4_plugin = get_mm_plugin(name="llama4", image_token="<|image|>")
    expected_mm_inputs = _get_mm_inputs(tokenizer_module["processor"])
    image_height, image_width = expected_mm_inputs["pixel_values"][0].shape[-2:]
    patch_size = processor.patch_size
    num_patches_per_chunk = int(
        (image_height // patch_size) * (image_width // patch_size) // processor.downsample_ratio
    )
    # The aspect ratios are used to build the expected prompt and are not part of the expected inputs.
    aspect_ratios = expected_mm_inputs.pop("aspect_ratios")
    tokens_for_this_image = processor._prompt_split_image(aspect_ratios[0], num_patches_per_chunk)
    expected_messages = [
        {field: text.replace("<image>", tokens_for_this_image) for field, text in message.items()}
        for message in MM_MESSAGES
    ]
    plugin_kwargs = {
        "plugin": llama4_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": expected_mm_inputs,
    }
    _check_plugin(**plugin_kwargs)
def test_llava_plugin():
    """Run the shared checks on the llava plugin with the llava-hf/llava-1.5-7b-hf tokenizer module."""
    image_seqlen = 576
    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
    llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
    # Each <image> placeholder is expected to be repeated image_seqlen times.
    image_block = "<image>" * image_seqlen
    expected_messages = [
        {field: text.replace("<image>", image_block) for field, text in message.items()}
        for message in MM_MESSAGES
    ]
    plugin_kwargs = {
        "plugin": llava_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": _get_mm_inputs(tokenizer_module["processor"]),
    }
    _check_plugin(**plugin_kwargs)
def test_llava_next_plugin():
    """Run the shared checks on the llava_next plugin with the llava-hf/llava-v1.6-vicuna-7b-hf tokenizer module."""
    image_seqlen = 1176
    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
    llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
    # Each <image> placeholder is expected to be repeated image_seqlen times.
    image_block = "<image>" * image_seqlen
    expected_messages = [
        {field: text.replace("<image>", image_block) for field, text in message.items()}
        for message in MM_MESSAGES
    ]
    plugin_kwargs = {
        "plugin": llava_next_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": _get_mm_inputs(tokenizer_module["processor"]),
    }
    _check_plugin(**plugin_kwargs)
def test_llava_next_video_plugin():
    """Run the shared checks on the llava_next_video plugin with the llava-hf/LLaVA-NeXT-Video-7B-hf tokenizer module."""
    image_seqlen = 1176
    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
    llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
    # Each <image> placeholder is expected to be repeated image_seqlen times.
    image_block = "<image>" * image_seqlen
    expected_messages = [
        {field: text.replace("<image>", image_block) for field, text in message.items()}
        for message in MM_MESSAGES
    ]
    plugin_kwargs = {
        "plugin": llava_next_video_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": _get_mm_inputs(tokenizer_module["processor"]),
    }
    _check_plugin(**plugin_kwargs)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
    """Run the shared checks on the paligemma plugin with the google/paligemma-3b-pt-224 tokenizer module."""
    image_seqlen = 256
    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
    paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
    # The <image> placeholder is expected to vanish from the text ...
    expected_messages = [
        {field: text.replace("<image>", "") for field, text in message.items()} for message in MM_MESSAGES
    ]
    # ... while image_seqlen image-token ids (labelled -100) are expected in front of the token ids.
    image_token_id = tokenizer_module["tokenizer"].convert_tokens_to_ids(paligemma_plugin.image_token)
    expected_input_ids = [image_token_id] * image_seqlen + INPUT_IDS
    expected_labels = [-100] * image_seqlen + LABELS
    expected_mm_inputs = _get_mm_inputs(tokenizer_module["processor"])
    expected_mm_inputs["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
    plugin_kwargs = {
        "plugin": paligemma_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_input_ids": expected_input_ids,
        "expected_labels": expected_labels,
        "expected_mm_inputs": expected_mm_inputs,
        "expected_no_mm_inputs": {"token_type_ids": [[1] * 1024]},
    }
    _check_plugin(**plugin_kwargs)
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_pixtral_plugin():
    """Run the shared checks on the pixtral plugin with the mistral-community/pixtral-12b tokenizer module."""
    image_slice_height, image_slice_width = 2, 2
    tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
    # Rows of [IMG] tokens separated by [IMG_BREAK]; the last break is cut off and [IMG_END] appended.
    image_rows = "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height
    image_block = image_rows.rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"
    expected_messages = [
        {field: text.replace("<image>", image_block) for field, text in message.items()}
        for message in MM_MESSAGES
    ]
    expected_mm_inputs = _get_mm_inputs(tokenizer_module["processor"])
    # Only the first entry of pixel_values is kept as the expected value.
    expected_mm_inputs["pixel_values"] = expected_mm_inputs["pixel_values"][0]
    plugin_kwargs = {
        "plugin": pixtral_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": expected_mm_inputs,
    }
    _check_plugin(**plugin_kwargs)
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_qwen2_omni_plugin():
    """Run the shared checks on the qwen2_omni plugin with the Qwen/Qwen2.5-Omni-7B tokenizer module."""
    image_seqlen, audio_seqlen = 4, 2
    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
    qwen2_omni_plugin = get_mm_plugin(
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
    )
    # Text expected in place of the <image> and <audio> placeholders.
    image_block = f"<|vision_bos|>{'<|IMAGE|>' * image_seqlen}<|vision_eos|>"
    audio_block = f"<|audio_bos|>{'<|AUDIO|>' * audio_seqlen}<|audio_eos|>"
    expected_messages = [
        {
            field: text.replace("<image>", image_block).replace("<audio>", audio_block)
            for field, text in message.items()
        }
        for message in OMNI_MESSAGES
    ]
    plugin_kwargs = {
        "plugin": qwen2_omni_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": _get_omni_inputs(tokenizer_module["processor"]),
    }
    _check_plugin(**plugin_kwargs)
def test_qwen2_vl_plugin():
    """Run the shared checks on the qwen2_vl plugin with the Qwen/Qwen2-VL-7B-Instruct tokenizer module."""
    image_seqlen = 4
    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
    qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
    # Text expected in place of each <image> placeholder.
    image_block = "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen)
    expected_messages = [
        {field: text.replace("<image>", image_block) for field, text in message.items()}
        for message in MM_MESSAGES
    ]
    plugin_kwargs = {
        "plugin": qwen2_vl_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": _get_mm_inputs(tokenizer_module["processor"]),
    }
    _check_plugin(**plugin_kwargs)
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():
    """Run the shared checks on the video_llava plugin with the LanguageBind/Video-LLaVA-7B-hf tokenizer module."""
    image_seqlen = 256
    tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
    video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
    # Each <image> placeholder is expected to be repeated image_seqlen times.
    image_block = "<image>" * image_seqlen
    expected_messages = [
        {field: text.replace("<image>", image_block) for field, text in message.items()}
        for message in MM_MESSAGES
    ]
    plugin_kwargs = {
        "plugin": video_llava_plugin,
        **tokenizer_module,
        "expected_mm_messages": expected_messages,
        "expected_mm_inputs": _get_mm_inputs(tokenizer_module["processor"]),
    }
    _check_plugin(**plugin_kwargs)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING
import pytest
from transformers import AutoTokenizer
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.template import parse_template
from llamafactory.hparams import DataArguments
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
# Hugging Face access token; passed to AutoTokenizer.from_pretrained in _check_template.
HF_TOKEN = os.getenv("HF_TOKEN")
# Model identifiers; each one can be overridden through the environment variable of the same name.
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA4 = os.getenv("TINY_LLAMA4", "llamafactory/tiny-random-Llama-4")
# Two-turn conversation (one English turn, one Chinese turn) without reasoning content.
MESSAGES = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "很高兴认识你!"},
]
# The same conversation with a <think>...</think> section in front of each assistant answer.
MESSAGES_WITH_THOUGHT = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "<think>\nModel thought here\n</think>\n\nI am fine!"},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "<think>\n模型思考内容\n</think>\n\n很高兴认识你!"},
]
def _check_tokenization(
tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
) -> None:
r"""Check token ids and texts.
encode(text) == token_ids
decode(token_ids) == text
"""
for input_ids, text in zip(batch_input_ids, batch_text):
assert tokenizer.encode(text, add_special_tokens=False) == input_ids
assert tokenizer.decode(input_ids) == text
def _check_template(
    model_id: str,
    template_name: str,
    prompt_str: str,
    answer_str: str,
    use_fast: bool,
    messages: list[dict[str, str]] = MESSAGES,
) -> None:
    r"""Compare a template's encoding with the tokenizer's own chat template.

    The tokenizer's chat template is rendered (as text and as ids) before the tokenizer is handed
    to ``get_template_and_fix_tokenizer``; the results must equal the concatenated prompt and
    answer parts.

    Args:
        model_id: the model id on hugging face hub.
        template_name: the template name.
        prompt_str: the string corresponding to the prompt part.
        answer_str: the string corresponding to the answer part.
        use_fast: whether to use fast tokenizer.
        messages: the list of messages.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
    # Reference rendering from the tokenizer's built-in chat template.
    reference_text = tokenizer.apply_chat_template(messages, tokenize=False)
    reference_ids = tokenizer.apply_chat_template(messages, tokenize=True)

    data_args = DataArguments(template=template_name)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)

    assert reference_text == prompt_str + answer_str
    assert reference_ids == prompt_ids + answer_ids
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
    """Encode MESSAGES as one turn with the llama3 template and check the prompt and answer parts."""
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    data_args = DataArguments(template="llama3")
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
    # Everything up to and including the final assistant header is expected in the prompt part.
    expected_prompt = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    expected_answer = "很高兴认识你!<|eot_id|>"
    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (expected_prompt, expected_answer))
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
    """Encode MESSAGES turn by turn with the llama3 template and check every prompt/answer pair."""
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    data_args = DataArguments(template="llama3")
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
    first_pair, second_pair = encoded_pairs[0], encoded_pairs[1]
    # Only the first prompt is expected to carry the begin-of-text token.
    expected_texts = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n",
        "I am fine!<|eot_id|>",
        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
        "很高兴认识你!<|eot_id|>",
    )
    encoded_ids = (first_pair[0], first_pair[1], second_pair[0], second_pair[1])
    _check_tokenization(tokenizer, encoded_ids, expected_texts)
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
    r"""Check one-turn encoding with the qwen3 template for each thinking configuration.

    ``enable_thinking`` is parametrized over ``True``, ``False`` and ``None``. Expected strings:

    * ``cot_messages`` set and ``enable_thinking`` not ``False``: the answer is the last
      message of ``MESSAGES_WITH_THOUGHT`` unchanged.
    * otherwise the answer is the last message of ``MESSAGES``; an empty ``<think>`` block
      is prepended to the answer when ``enable_thinking`` is truthy and appended to the
      prompt when it is not.
    """
    empty_thought = "<think>\n\n</think>\n\n"

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
    template = get_template_and_fix_tokenizer(
        tokenizer, DataArguments(template="qwen3", enable_thinking=enable_thinking)
    )
    source_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, source_messages)

    prompt_str = (
        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
        f"{MESSAGES[1]['content']}<|im_end|>\n"
        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
    )
    if cot_messages and enable_thinking is not False:
        answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
    elif enable_thinking:
        answer_str = empty_thought + f"{MESSAGES[3]['content']}<|im_end|>\n"
    else:
        answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
        prompt_str += empty_thought

    _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
    r"""Check multi-turn encoding with the qwen3 template for each thinking configuration.

    ``enable_thinking`` is parametrized over ``True``, ``False`` and ``None``. When
    ``cot_messages`` is set and thinking is not explicitly disabled, the expected answers
    come from ``MESSAGES_WITH_THOUGHT``; otherwise they come from ``MESSAGES`` and an empty
    ``<think>`` block is added to both answers (``enable_thinking`` truthy) or to both
    prompts (``enable_thinking`` falsy).
    """
    empty_thought = "<think>\n\n</think>\n\n"

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
    template = get_template_and_fix_tokenizer(
        tokenizer, DataArguments(template="qwen3", enable_thinking=enable_thinking)
    )
    source_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
    encoded_pairs = template.encode_multiturn(tokenizer, source_messages)

    thought_dropped = not cot_messages or enable_thinking is False
    messages = MESSAGES if thought_dropped else MESSAGES_WITH_THOUGHT
    prompt_str_1 = f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
    answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
    prompt_str_2 = f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
    answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
    if thought_dropped and enable_thinking:
        answer_str_1 = empty_thought + answer_str_1
        answer_str_2 = empty_thought + answer_str_2
    elif thought_dropped:
        prompt_str_1 += empty_thought
        prompt_str_2 += empty_thought

    first_turn, second_turn = encoded_pairs[0], encoded_pairs[1]
    _check_tokenization(
        tokenizer,
        (first_turn[0], first_turn[1], second_turn[0], second_turn[1]),
        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
    )
@pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool):
    r"""Compare the template's generated jinja chat template with the tokenizer's own.

    The generated template text must differ from the reference one, while applying either
    to ``MESSAGES`` must give the same result.
    """
    patched_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    reference_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
    template = get_template_and_fix_tokenizer(patched_tokenizer, DataArguments(template="llama3"))
    # llama3 template no replace
    patched_tokenizer.chat_template = template._get_jinja_template(patched_tokenizer)
    assert patched_tokenizer.chat_template != reference_tokenizer.chat_template
    assert patched_tokenizer.apply_chat_template(MESSAGES) == reference_tokenizer.apply_chat_template(MESSAGES)
def test_ollama_modelfile():
    r"""Check the Ollama modelfile text produced by the llama3 template."""
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
    generated_modelfile = template.get_ollama_modelfile(tokenizer)
    expected_modelfile = (
        "# ollama modelfile auto-generated by llamafactory\n\n"
        "FROM .\n\n"
        'TEMPLATE """<|begin_of_text|>'
        "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}"
        '{{ range .Messages }}{{ if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Content }}'
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        '{{ else if eq .Role "assistant" }}{{ .Content }}<|eot_id|>{{ end }}{{ end }}"""\n\n'
        'PARAMETER stop "<|eom_id|>"\n'
        'PARAMETER stop "<|eot_id|>"\n'
        "PARAMETER num_ctx 4096\n"
    )
    assert generated_modelfile == expected_modelfile
def test_get_stop_token_ids():
    r"""Check that the llama3 template reports exactly the stop token ids 128008 and 128009."""
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
    data_args = DataArguments(template="llama3")
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    stop_token_ids = template.get_stop_token_ids(tokenizer)
    # Compared as a set, so order and duplicates in the returned ids do not matter.
    assert set(stop_token_ids) == {128008, 128009}
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
    r"""Run ``_check_template`` for the ``gemma`` template with ``google/gemma-3-4b-it``.

    Skipped when ``HF_TOKEN`` is not set.
    """
    first_query, first_reply, second_query, final_reply = (MESSAGES[idx]["content"] for idx in range(4))
    prompt_str = (
        f"<bos><start_of_turn>user\n{first_query}<end_of_turn>\n"
        f"<start_of_turn>model\n{first_reply}<end_of_turn>\n"
        f"<start_of_turn>user\n{second_query}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    answer_str = f"{final_reply}<end_of_turn>\n"
    _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma2_template(use_fast: bool):
    r"""Run ``_check_template`` for the ``gemma2`` template with ``google/gemma-2-2b-it``.

    Skipped when ``HF_TOKEN`` is not set.
    """
    first_query, first_reply, second_query, final_reply = (MESSAGES[idx]["content"] for idx in range(4))
    prompt_str = (
        f"<bos><start_of_turn>user\n{first_query}<end_of_turn>\n"
        f"<start_of_turn>model\n{first_reply}<end_of_turn>\n"
        f"<start_of_turn>user\n{second_query}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    answer_str = f"{final_reply}<end_of_turn>\n"
    _check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
    r"""Run ``_check_template`` for the ``llama3`` template with ``meta-llama/Meta-Llama-3-8B-Instruct``.

    Skipped when ``HF_TOKEN`` is not set.
    """
    first_query, first_reply, second_query, final_reply = (MESSAGES[idx]["content"] for idx in range(4))
    prompt_str = (
        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{first_query}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n{first_reply}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|>\n\n{second_query}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    answer_str = f"{final_reply}<|eot_id|>"
    _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
@pytest.mark.parametrize(
    "use_fast",
    [
        True,
        pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer.")),
    ],
)
def test_llama4_template(use_fast: bool):
    r"""Run ``_check_template`` for the ``llama4`` template with the ``TINY_LLAMA4`` model.

    The ``use_fast=False`` case is marked as an expected failure.
    """
    first_query, first_reply, second_query, final_reply = (MESSAGES[idx]["content"] for idx in range(4))
    prompt_str = (
        f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{first_query}<|eot|>"
        f"<|header_start|>assistant<|header_end|>\n\n{first_reply}<|eot|>"
        f"<|header_start|>user<|header_end|>\n\n{second_query}<|eot|>"
        "<|header_start|>assistant<|header_end|>\n\n"
    )
    answer_str = f"{final_reply}<|eot|>"
    _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
@pytest.mark.parametrize(
    "use_fast",
    [
        pytest.param(True, marks=pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")),
        pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
    ],
)
def test_phi4_template(use_fast: bool):
    r"""Run ``_check_template`` for the ``phi4`` template with ``microsoft/phi-4``.

    The fast case is an expected failure when ``HF_TOKEN`` is not set; the slow case is
    always marked as an expected failure.
    """
    first_query, first_reply, second_query, final_reply = (MESSAGES[idx]["content"] for idx in range(4))
    prompt_str = (
        f"<|im_start|>user<|im_sep|>{first_query}<|im_end|>"
        f"<|im_start|>assistant<|im_sep|>{first_reply}<|im_end|>"
        f"<|im_start|>user<|im_sep|>{second_query}<|im_end|>"
        "<|im_start|>assistant<|im_sep|>"
    )
    answer_str = f"{final_reply}<|im_end|>"
    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen2_5_template(use_fast: bool):
    r"""Run ``_check_template`` for the ``qwen`` template with ``Qwen/Qwen2.5-7B-Instruct``.

    The expected prompt starts with a fixed system message. The test is an expected failure
    when ``HF_TOKEN`` is not set.
    """
    first_query, first_reply, second_query, final_reply = (MESSAGES[idx]["content"] for idx in range(4))
    prompt_str = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{first_query}<|im_end|>\n"
        f"<|im_start|>assistant\n{first_reply}<|im_end|>\n"
        f"<|im_start|>user\n{second_query}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    answer_str = f"{final_reply}<|im_end|>\n"
    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
@pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False])
def test_qwen3_template(use_fast: bool, cot_messages: bool):
    r"""Run ``_check_template`` for the ``qwen3`` template with ``Qwen/Qwen3-8B``.

    With ``cot_messages`` the messages and expected answer come from
    ``MESSAGES_WITH_THOUGHT``; without it ``MESSAGES`` is used and the expected answer
    starts with an empty ``<think>`` block.
    """
    first_query, first_reply, second_query = (MESSAGES[idx]["content"] for idx in range(3))
    prompt_str = (
        f"<|im_start|>user\n{first_query}<|im_end|>\n"
        f"<|im_start|>assistant\n{first_reply}<|im_end|>\n"
        f"<|im_start|>user\n{second_query}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    if cot_messages:
        messages = MESSAGES_WITH_THOUGHT
        answer_str = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
    else:
        messages = MESSAGES
        answer_str = f"<think>\n\n</think>\n\n{MESSAGES[3]['content']}<|im_end|>\n"

    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
def test_parse_llama3_template():
    r"""Parse a template from the ``TINY_LLAMA3`` tokenizer and check its slots and default system."""
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
    parsed = parse_template(tokenizer)

    expected_user_slots = [
        "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    ]
    expected_assistant_slots = ["{{content}}<|eot_id|>"]
    expected_system_slots = ["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
    expected_prefix_slots = ["<|begin_of_text|>"]

    assert parsed.format_user.slots == expected_user_slots
    assert parsed.format_assistant.slots == expected_assistant_slots
    assert parsed.format_system.slots == expected_system_slots
    assert parsed.format_prefix.slots == expected_prefix_slots
    assert parsed.default_system == ""
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen_template():
    r"""Parse a template from the ``Qwen/Qwen2.5-7B-Instruct`` tokenizer and check the result.

    The parsed object must be a ``Template`` with the expected slots and default system
    message. The test is an expected failure when ``HF_TOKEN`` is not set.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
    parsed = parse_template(tokenizer)

    expected_user_slots = ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    expected_assistant_slots = ["{{content}}<|im_end|>\n"]
    expected_system_slots = ["<|im_start|>system\n{{content}}<|im_end|>\n"]
    expected_default_system = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."

    assert parsed.__class__.__name__ == "Template"
    assert parsed.format_user.slots == expected_user_slots
    assert parsed.format_assistant.slots == expected_assistant_slots
    assert parsed.format_system.slots == expected_system_slots
    assert parsed.format_prefix.slots == []
    assert parsed.default_system == expected_default_system
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen3_template():
    r"""Parse a template from the ``Qwen/Qwen3-8B`` tokenizer and check the result.

    The parsed object must be a ``ReasoningTemplate`` with the expected slots and an empty
    default system message. The test is an expected failure when ``HF_TOKEN`` is not set.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
    parsed = parse_template(tokenizer)

    expected_user_slots = ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    expected_assistant_slots = ["{{content}}<|im_end|>\n"]
    expected_system_slots = ["<|im_start|>system\n{{content}}<|im_end|>\n"]

    assert parsed.__class__.__name__ == "ReasoningTemplate"
    assert parsed.format_user.slots == expected_user_slots
    assert parsed.format_assistant.slots == expected_assistant_slots
    assert parsed.format_system.slots == expected_system_slots
    assert parsed.format_prefix.slots == []
    assert parsed.default_system == ""
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from llamafactory.chat import ChatModel
# Model name/path used by the chat tests; the TINY_LLAMA3 environment variable overrides the default.
TINY_LLAMA3 = os.environ.get("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

# Arguments passed to ChatModel: sampling is switched off and a single new token is requested.
INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "do_sample": False,
    "max_new_tokens": 1,
}

# Conversation sent to the model in both tests.
MESSAGES = [{"role": "user", "content": "Hi"}]

# Text that test_chat and test_stream_chat compare the model output against.
EXPECTED_RESPONSE = "_rho"
def test_chat():
    r"""Check that the first chat response for ``MESSAGES`` has the expected text."""
    chat_model = ChatModel(INFER_ARGS)
    responses = chat_model.chat(MESSAGES)
    first_response = responses[0]
    assert first_response.response_text == EXPECTED_RESPONSE
def test_stream_chat():
    r"""Check that the concatenated streamed tokens for ``MESSAGES`` equal the expected text."""
    chat_model = ChatModel(INFER_ARGS)
    streamed_text = "".join(chat_model.stream_chat(MESSAGES))
    assert streamed_text == EXPECTED_RESPONSE
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import pytest
from llamafactory.chat import ChatModel
from llamafactory.extras.packages import is_sglang_available
# Model name loaded by the SGLang chat tests.
MODEL_NAME = "Qwen/Qwen2.5-0.5B"

# Arguments passed to ChatModel; "infer_backend" is set to "sglang".
# NOTE(review): the "llama3" template is paired with a Qwen model here — confirm this is intended.
INFER_ARGS = {
    "model_name_or_path": MODEL_NAME,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "infer_backend": "sglang",
    "do_sample": False,
    "max_new_tokens": 1,
}

# Conversation sent to the model in both tests.
MESSAGES = [
    {"role": "user", "content": "Hi"},
]
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_chat():
    r"""Test the SGLang engine's basic chat functionality.

    The response text is printed and must be non-empty, the same check the streaming
    test applies, so the test can fail when the engine returns nothing.
    """
    chat_model = ChatModel(INFER_ARGS)
    response = chat_model.chat(MESSAGES)[0]
    # TODO: Change to EXPECTED_RESPONSE
    print(response.response_text)
    assert response.response_text, "Should receive a non-empty response"
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_stream_chat():
    r"""Test the SGLang engine's streaming chat functionality.

    All streamed tokens are joined into one string, which is printed and must be non-empty.
    """
    chat_model = ChatModel(INFER_ARGS)
    streamed_text = "".join(chat_model.stream_chat(MESSAGES))
    print("Complete response:", streamed_text)
    assert streamed_text, "Should receive a non-empty response"
# Script entry point: run both tests when SGLang is available, otherwise report and exit with status 1.
if __name__ == "__main__":
    if is_sglang_available():
        test_chat()
        test_stream_chat()
    else:
        print("SGLang is not available. Please install it.")
        sys.exit(1)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from llamafactory.train.tuner import export_model, run_exp
# Each default below can be overridden through the environment variable of the same name.
DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA3 = os.environ.get("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")

# Arguments shared by every test_run_exp case; stage, dataset and output_dir are added per case.
TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "do_train": True,
    "finetuning_type": "lora",
    "dataset_dir": f"REMOTE:{DEMO_DATA}",
    "template": "llama3",
    "cutoff_len": 1,
    "overwrite_output_dir": True,
    "per_device_train_batch_size": 1,
    "max_steps": 1,
    "report_to": "none",
}

# Arguments used by test_export; export_dir is added by the test itself.
INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "adapter_name_or_path": TINY_LLAMA_ADAPTER,
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
}

# Read from the environment; empty when unset. Used to mark the "rm" stage as xfail on Windows.
OS_NAME = os.environ.get("OS_NAME", "")
@pytest.mark.parametrize(
    "stage,dataset",
    [
        ("pt", "c4_demo"),
        ("sft", "alpaca_en_demo"),
        ("dpo", "dpo_en_demo"),
        ("kto", "kto_en_demo"),
        pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
    ],
)
def test_run_exp(stage: str, dataset: str):
    r"""Run one experiment for the given stage/dataset pair and check its output directory exists.

    The "rm" case is an expected failure when ``OS_NAME`` starts with "windows".
    """
    output_dir = os.path.join("output", f"train_{stage}")
    run_args = {"stage": stage, "dataset": dataset, "output_dir": output_dir}
    # Applied last so the shared arguments keep the same precedence as in a trailing ** unpacking.
    run_args.update(TRAIN_ARGS)
    run_exp(run_args)
    assert os.path.exists(output_dir)
def test_export():
    r"""Export the model described by ``INFER_ARGS`` and check the export directory exists."""
    export_dir = os.path.join("output", "llama3_export")
    export_args = {"export_dir": export_dir}
    # Applied last so the shared arguments keep the same precedence as in a trailing ** unpacking.
    export_args.update(INFER_ARGS)
    export_model(export_args)
    assert os.path.exists(export_dir)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.eval.template import get_eval_template
def test_eval_template_en():
    r"""Format one few-shot example plus a target question with the "en" eval template.

    The expected output is four messages: the few-shot question (prefixed with the subject
    header) and its answer, then the target question and its answer.
    """
    few_shot_example = {
        "question": "Fewshot question",
        "A": "Fewshot1",
        "B": "Fewshot2",
        "C": "Fewshot3",
        "D": "Fewshot4",
        "answer": "B",
    }
    target_example = {
        "question": "Target question",
        "A": "Target1",
        "B": "Target2",
        "C": "Target3",
        "D": "Target4",
        "answer": "C",
    }
    few_shot_prompt = (
        "The following are multiple choice questions (with answers) about SubName.\n\n"
        "Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:"
    )
    target_prompt = "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:"
    expected_messages = [
        {"role": "user", "content": few_shot_prompt},
        {"role": "assistant", "content": "B"},
        {"role": "user", "content": target_prompt},
        {"role": "assistant", "content": "C"},
    ]

    template = get_eval_template(name="en")
    messages = template.format_example(target_example, support_set=[few_shot_example], subject_name="SubName")
    assert messages == expected_messages
def test_eval_template_zh():
    r"""Format one few-shot example plus a target question with the "zh" eval template.

    The expected output is four messages: the few-shot question (prefixed with the subject
    header) and its answer, then the target question and its answer.
    """
    few_shot_example = {
        "question": "示例问题",
        "A": "示例答案1",
        "B": "示例答案2",
        "C": "示例答案3",
        "D": "示例答案4",
        "answer": "B",
    }
    target_example = {
        "question": "目标问题",
        "A": "目标答案1",
        "B": "目标答案2",
        "C": "目标答案3",
        "D": "目标答案4",
        "answer": "C",
    }
    few_shot_prompt = (
        "以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n"
        "示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:"
    )
    target_prompt = "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:"
    expected_messages = [
        {"role": "user", "content": few_shot_prompt},
        {"role": "assistant", "content": "B"},
        {"role": "user", "content": target_prompt},
        {"role": "assistant", "content": "C"},
    ]

    template = get_eval_template(name="zh")
    messages = template.format_example(target_example, support_set=[few_shot_example], subject_name="主题")
    assert messages == expected_messages
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer
# Model name/path whose tokenizer is loaded; the TINY_LLAMA3 environment variable overrides the default.
TINY_LLAMA3 = os.environ.get("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

# Token string that test_add_tokens adds to the tokenizer.
UNUSED_TOKEN = "<|UNUSED_TOKEN|>"
@pytest.mark.parametrize("special_tokens", [False, True])
def test_add_tokens(special_tokens: bool):
    r"""Add ``UNUSED_TOKEN`` to the tokenizer and check how it encodes and decodes.

    The token is passed through ``add_special_tokens`` or ``add_tokens`` depending on
    ``special_tokens``. It must encode to a single id; decoding with
    ``skip_special_tokens=True`` must give an empty string in the special case and the
    token itself otherwise.
    """
    argument_name = "add_special_tokens" if special_tokens else "add_tokens"
    model_args = ModelArguments(model_name_or_path=TINY_LLAMA3, **{argument_name: UNUSED_TOKEN})
    tokenizer = load_tokenizer(model_args)["tokenizer"]

    encoded_ids = tokenizer.encode(UNUSED_TOKEN, add_special_tokens=False)
    assert len(encoded_ids) == 1

    decoded_str = tokenizer.decode(encoded_ids, skip_special_tokens=True)
    expected_str = "" if special_tokens else UNUSED_TOKEN
    assert decoded_str == expected_str
# Script entry point: run this file's tests through pytest.
if __name__ == "__main__":
    # Propagate pytest's exit status so a failing run is visible to the calling shell.
    raise SystemExit(pytest.main([__file__]))
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import load_infer_model
# Model name/path used by the attention test; the TINY_LLAMA3 environment variable overrides the default.
TINY_LLAMA3 = os.environ.get("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

# Arguments forwarded to load_infer_model next to the requested flash_attn value.
INFER_ARGS = {"model_name_or_path": TINY_LLAMA3, "template": "llama3"}
@pytest.mark.xfail(is_transformers_version_greater_than("4.48"), reason="Attention refactor.")
def test_attention():
    r"""Load the model with each available attention setting and check the attention class names.

    "disabled" is always tested; "sdpa" and "fa2" are added when the corresponding
    availability check passes. Every module whose class name contains "Attention" must
    have the class name mapped to the requested setting.
    """
    llama_attention_classes = {
        "disabled": "LlamaAttention",
        "sdpa": "LlamaSdpaAttention",
        "fa2": "LlamaFlashAttention2",
    }
    optional_backends = (("sdpa", is_torch_sdpa_available), ("fa2", is_flash_attn_2_available))
    attention_available = ["disabled", *(name for name, is_available in optional_backends if is_available())]

    for requested_attention in attention_available:
        expected_class_name = llama_attention_classes[requested_attention]
        model = load_infer_model(flash_attn=requested_attention, **INFER_ARGS)
        for module in model.modules():
            class_name = module.__class__.__name__
            if "Attention" not in class_name:
                continue

            assert class_name == expected_class_name
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from llamafactory.extras.misc import get_current_device
from llamafactory.train.test_utils import load_train_model
# Model name/path used by the tests; the TINY_LLAMA3 environment variable overrides the default.
TINY_LLAMA3 = os.environ.get("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

# Base arguments for load_train_model; each test adds the option it exercises.
TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "lora_target": "all",
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}
@pytest.mark.parametrize("disable_gradient_checkpointing", [False, True])
def test_vanilla_checkpointing(disable_gradient_checkpointing: bool):
    r"""Check the ``gradient_checkpointing`` flag on every module that has one.

    The flag must differ from ``disable_gradient_checkpointing``.
    """
    model = load_train_model(disable_gradient_checkpointing=disable_gradient_checkpointing, **TRAIN_ARGS)
    for module in model.modules():
        if not hasattr(module, "gradient_checkpointing"):
            continue

        assert module.gradient_checkpointing != disable_gradient_checkpointing
def test_unsloth_gradient_checkpointing():
    r"""Check that ``use_unsloth_gc=True`` installs ``UnslothGradientCheckpointing``.

    For every module with a ``gradient_checkpointing`` attribute, the object bound to its
    ``_gradient_checkpointing_func`` must be named ``UnslothGradientCheckpointing``.
    """
    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
    for module in model.modules():
        if not hasattr(module, "gradient_checkpointing"):
            continue

        checkpoint_owner = module._gradient_checkpointing_func.__self__
        assert checkpoint_owner.__name__ == "UnslothGradientCheckpointing"
def test_upcast_layernorm():
    r"""Check that 1-D parameters with "norm" in their name are float32 under ``upcast_layernorm=True``."""
    model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
    for name, param in model.named_parameters():
        if param.ndim != 1 or "norm" not in name:
            continue

        assert param.dtype == torch.float32
def test_upcast_lmhead_output():
    r"""Check that the output embeddings return float32 for float16 input under ``upcast_lmhead_output=True``."""
    model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
    half_inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
    lm_head = model.get_output_embeddings()
    outputs: torch.Tensor = lm_head(half_inputs)
    assert outputs.dtype == torch.float32
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from llamafactory.model.model_utils.misc import find_expanded_modules
# Read from the environment (None when unset); test_expanded_modules is skipped without it.
HF_TOKEN = os.environ.get("HF_TOKEN")
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_expanded_modules():
    r"""Check which modules ``find_expanded_modules`` returns for ``num_layer_trainable=4``.

    The model is built from the ``meta-llama/Meta-Llama-3-8B-Instruct`` config on the meta
    device. The expected result is ``q_proj`` and ``v_proj`` of layers 7, 15, 23 and 31,
    in that order.
    """
    target_modules = ["q_proj", "v_proj"]
    expected_modules = [
        f"model.layers.{layer_idx}.self_attn.{module_name}"
        for layer_idx in (7, 15, 23, 31)
        for module_name in target_modules
    ]

    config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)

    expanded_modules = find_expanded_modules(model, ["q_proj", "v_proj"], num_layer_trainable=4)
    assert expanded_modules == expected_modules
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from llamafactory.model.model_utils.packing import get_seqlens_in_batch, get_unpad_data
@pytest.mark.parametrize(
    "attention_mask,golden_seq_lens",
    [
        ([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]], [2, 3, 1, 2, 3]),
        ([[1]], [1]),
    ],
)
def test_get_seqlens_in_batch(attention_mask, golden_seq_lens):
    r"""Check ``get_seqlens_in_batch`` against the expected lengths for each attention mask."""
    mask_tensor = torch.tensor(attention_mask)
    expected_seqlens = torch.tensor(golden_seq_lens)
    seqlens_in_batch = get_seqlens_in_batch(mask_tensor)
    assert (seqlens_in_batch == expected_seqlens).all()
@pytest.mark.parametrize(
    "attention_mask,golden_indices,golden_cu_seqlens,golden_max_seqlen",
    [
        (
            [[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]],
            [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11],
            [0, 2, 5, 6, 8, 11],
            3,
        ),
        ([[1]], [0], [0, 1], 1),
    ],
)
def test_get_unpad_data(attention_mask, golden_indices, golden_cu_seqlens, golden_max_seqlen):
    r"""Check the indices, cumulative lengths and maximum length returned by ``get_unpad_data``."""
    mask_tensor = torch.tensor(attention_mask)
    expected_indices = torch.tensor(golden_indices)
    expected_cu_seqlens = torch.tensor(golden_cu_seqlens, dtype=torch.int32)

    indices, cu_seqlens, max_seqlen_in_batch = get_unpad_data(mask_tensor)
    assert (indices == expected_indices).all()
    assert (cu_seqlens == expected_cu_seqlens).all()
    assert max_seqlen_in_batch == golden_max_seqlen
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from transformers import AutoConfig, AutoModelForVision2Seq
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import FinetuningArguments, ModelArguments
from llamafactory.model.adapter import init_adapter
@pytest.mark.parametrize("freeze_vision_tower", (False, True))
@pytest.mark.parametrize("freeze_multi_modal_projector", (False, True))
@pytest.mark.parametrize("freeze_language_model", (False, True))
def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, freeze_language_model: bool):
    r"""Check ``requires_grad`` of every parameter after full fine-tuning setup with freeze flags.

    Parameters are matched by name: ``visual.patch_embed`` / ``visual.blocks`` follow
    ``freeze_vision_tower``, ``visual.merger`` follows ``freeze_multi_modal_projector``,
    and all remaining parameters follow ``freeze_language_model``.
    """
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
    finetuning_args = FinetuningArguments(
        finetuning_type="full",
        freeze_vision_tower=freeze_vision_tower,
        freeze_multi_modal_projector=freeze_multi_modal_projector,
        freeze_language_model=freeze_language_model,
    )
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)

    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
    for name, param in model.named_parameters():
        if "visual.patch_embed" in name or "visual.blocks" in name:
            expected_frozen = freeze_vision_tower
        elif "visual.merger" in name:
            expected_frozen = freeze_multi_modal_projector
        else:
            expected_frozen = freeze_language_model

        assert param.requires_grad != expected_frozen
@pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False)))
def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
    r"""Check which LoRA weights are trainable for each combination of freeze flags.

    One representative ``lora_A`` weight is looked up for the vision blocks, the language
    model and the merger. The parameter names depend on whether the installed transformers
    version is greater than 4.52.0. The merger weight must never be trainable.
    """
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
    finetuning_args = FinetuningArguments(
        finetuning_type="lora",
        freeze_vision_tower=freeze_vision_tower,
        freeze_language_model=freeze_language_model,
    )
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)

    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=True)
    trainable_params = {name for name, param in model.named_parameters() if param.requires_grad}

    if is_transformers_version_greater_than("4.52.0"):
        visual_param_name, language_param_name, merger_param_name = (
            "base_model.model.model.visual.blocks.0.attn.qkv.lora_A.default.weight",
            "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.default.weight",
            "base_model.model.model.visual.merger.lora_A.default.weight",
        )
    else:
        visual_param_name, language_param_name, merger_param_name = (
            "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight",
            "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
            "base_model.model.visual.merger.lora_A.default.weight",
        )

    visual_trainable = visual_param_name in trainable_params
    language_trainable = language_param_name in trainable_params
    merger_trainable = merger_param_name in trainable_params
    assert visual_trainable != freeze_vision_tower
    assert language_trainable != freeze_language_model
    assert merger_trainable is False
def test_visual_model_save_load():
    r"""Check the q_proj weight key in the loaded model and in the saved checkpoint.

    The model is saved under ``output/qwen2_vl`` and the resulting ``pytorch_model.bin`` is
    read back. The key expected in the loaded model depends on whether the installed
    transformers version is greater than 4.52.0; the key expected in the saved file does not.
    """
    # check VLM's state dict: https://github.com/huggingface/transformers/pull/38385
    save_dir = os.path.join("output", "qwen2_vl")
    model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
    finetuning_args = FinetuningArguments(finetuning_type="full")
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)

    model = init_adapter(config, model, model_args, finetuning_args, is_trainable=False)
    loaded_model_weight = dict(model.named_parameters())
    model.save_pretrained(save_dir, max_shard_size="10GB", safe_serialization=False)
    saved_model_weight = torch.load(os.path.join(save_dir, "pytorch_model.bin"), weights_only=False)

    if is_transformers_version_greater_than("4.52.0"):
        loaded_key = "model.language_model.layers.0.self_attn.q_proj.weight"
    else:
        loaded_key = "model.layers.0.self_attn.q_proj.weight"

    assert loaded_key in loaded_model_weight
    assert "model.layers.0.self_attn.q_proj.weight" in saved_model_weight
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, patch_valuehead_model
# Each default below can be overridden through the environment variable of the same name.
TINY_LLAMA3 = os.environ.get("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")

# Arguments forwarded to load_infer_model by both tests.
INFER_ARGS = {"model_name_or_path": TINY_LLAMA3, "template": "llama3", "infer_dtype": "float16"}
@pytest.fixture
def fix_valuehead_cpu_loading():
    r"""Fixture that calls ``patch_valuehead_model`` before the requesting test runs."""
    patch_valuehead_model()
def test_base():
    """Compare the model loaded with ``INFER_ARGS`` against the ``TINY_LLAMA3`` reference model."""
    inference_model = load_infer_model(**INFER_ARGS)
    reference_model = load_reference_model(TINY_LLAMA3)
    compare_model(inference_model, reference_model)
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_valuehead():
    """Compare a model loaded with ``add_valuehead=True`` against the value-head reference model."""
    inference_model = load_infer_model(add_valuehead=True, **INFER_ARGS)
    reference_model = load_reference_model(TINY_LLAMA_VALUEHEAD, add_valuehead=True)
    compare_model(inference_model, reference_model)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
# Base model identifier; may be overridden through the ``TINY_LLAMA3`` environment variable.
TINY_LLAMA3 = os.environ.get("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

# Keyword arguments unpacked into ``load_train_model`` by the training tests in this module.
TRAIN_ARGS = dict(
    model_name_or_path=TINY_LLAMA3,
    stage="sft",
    do_train=True,
    finetuning_type="freeze",
    dataset="llamafactory/tiny-supervised-dataset",
    dataset_dir="ONLINE",
    template="llama3",
    cutoff_len=1024,
    output_dir="dummy_dir",
    overwrite_output_dir=True,
    fp16=True,
)

# Keyword arguments unpacked into ``load_infer_model`` by the inference test in this module.
INFER_ARGS = dict(
    model_name_or_path=TINY_LLAMA3,
    finetuning_type="freeze",
    template="llama3",
    infer_dtype="float16",
)
def test_freeze_train_all_modules():
    """Check which parameters are trainable when ``freeze_trainable_layers=1``.

    Parameters whose name starts with ``model.layers.1.`` must require
    gradients and be float32; every other parameter must be frozen and float16.
    """
    model = load_train_model(freeze_trainable_layers=1, **TRAIN_ARGS)
    for param_name, weight in model.named_parameters():
        should_train = param_name.startswith("model.layers.1.")
        expected_dtype = torch.float32 if should_train else torch.float16
        assert weight.requires_grad is should_train
        assert weight.dtype == expected_dtype
def test_freeze_train_extra_modules():
    """Check trainability when ``freeze_extra_modules`` is set in addition to one trainable layer.

    Parameters under ``model.layers.1.`` or whose name contains ``embed_tokens``
    or ``lm_head`` must require gradients and be float32; every other
    parameter must be frozen and float16.
    """
    model = load_train_model(freeze_trainable_layers=1, freeze_extra_modules="embed_tokens,lm_head", **TRAIN_ARGS)
    extra_module_names = ["embed_tokens", "lm_head"]
    for param_name, weight in model.named_parameters():
        should_train = param_name.startswith("model.layers.1.") or any(
            fragment in param_name for fragment in extra_module_names
        )
        expected_dtype = torch.float32 if should_train else torch.float16
        assert weight.requires_grad is should_train
        assert weight.dtype == expected_dtype
def test_freeze_inference():
    """Check that every parameter of the inference model is frozen and stored in float16."""
    inference_model = load_infer_model(**INFER_ARGS)
    for weight in inference_model.parameters():
        assert weight.requires_grad is False
        assert weight.dtype == torch.float16
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
# Base model identifier; may be overridden through the ``TINY_LLAMA3`` environment variable.
TINY_LLAMA3 = os.environ.get("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

# Keyword arguments unpacked into ``load_train_model`` by the training test in this module.
TRAIN_ARGS = dict(
    model_name_or_path=TINY_LLAMA3,
    stage="sft",
    do_train=True,
    finetuning_type="full",
    dataset="llamafactory/tiny-supervised-dataset",
    dataset_dir="ONLINE",
    template="llama3",
    cutoff_len=1024,
    output_dir="dummy_dir",
    overwrite_output_dir=True,
    fp16=True,
)

# Keyword arguments unpacked into ``load_infer_model`` by the inference test in this module.
INFER_ARGS = dict(
    model_name_or_path=TINY_LLAMA3,
    finetuning_type="full",
    template="llama3",
    infer_dtype="float16",
)
def test_full_train():
    """Check that every parameter of the training model requires gradients and is float32."""
    train_model = load_train_model(**TRAIN_ARGS)
    for weight in train_model.parameters():
        assert weight.requires_grad is True
        assert weight.dtype == torch.float32
def test_full_inference():
    """Check that every parameter of the inference model is frozen and stored in float16."""
    inference_model = load_infer_model(**INFER_ARGS)
    for weight in inference_model.parameters():
        assert weight.requires_grad is False
        assert weight.dtype == torch.float16
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from llamafactory.train.test_utils import (
check_lora_model,
compare_model,
load_infer_model,
load_reference_model,
load_train_model,
patch_valuehead_model,
)
# Model and adapter identifiers; each one may be overridden through the
# environment variable that carries the same name as the constant.
TINY_LLAMA3 = os.environ.get("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
TINY_LLAMA_VALUEHEAD = os.environ.get(
    "TINY_LLAMA_VALUEHEAD",
    "llamafactory/tiny-random-Llama-3-valuehead",
)

# Keyword arguments unpacked into ``load_train_model`` by the training tests in this module.
TRAIN_ARGS = dict(
    model_name_or_path=TINY_LLAMA3,
    stage="sft",
    do_train=True,
    finetuning_type="lora",
    dataset="llamafactory/tiny-supervised-dataset",
    dataset_dir="ONLINE",
    template="llama3",
    cutoff_len=1024,
    output_dir="dummy_dir",
    overwrite_output_dir=True,
    fp16=True,
)

# Keyword arguments unpacked into ``load_infer_model`` by the inference test in this module.
INFER_ARGS = dict(
    model_name_or_path=TINY_LLAMA3,
    adapter_name_or_path=TINY_LLAMA_ADAPTER,
    finetuning_type="lora",
    template="llama3",
    infer_dtype="float16",
)
@pytest.fixture
def fix_valuehead_cpu_loading():
    """Run ``patch_valuehead_model()`` before any test that requests this fixture.

    NOTE(review): judging by the fixture name, the patch works around loading
    value-head models on CPU -- confirm against ``patch_valuehead_model``.
    """
    patch_valuehead_model()
def test_lora_train_qv_modules():
    """Check that ``lora_target="q_proj,v_proj"`` yields exactly those two linear modules."""
    expected_modules = {"q_proj", "v_proj"}
    model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
    found_modules, _ignored = check_lora_model(model)
    assert found_modules == expected_modules
def test_lora_train_all_modules():
    """Check that ``lora_target="all"`` yields the seven attention and MLP projection modules."""
    expected_modules = {
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "up_proj",
        "gate_proj",
        "down_proj",
    }
    model = load_train_model(lora_target="all", **TRAIN_ARGS)
    found_modules, _ignored = check_lora_model(model)
    assert found_modules == expected_modules
def test_lora_train_extra_modules():
    """Check that ``additional_target`` is reported as the extra modules by ``check_lora_model``."""
    expected_extras = {"embed_tokens", "lm_head"}
    model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
    _ignored, found_extras = check_lora_model(model)
    assert found_extras == expected_extras
def test_lora_train_old_adapters():
    """Compare a model trained on the existing adapter (``create_new_adapter=False``) with the reference."""
    train_model = load_train_model(
        adapter_name_or_path=TINY_LLAMA_ADAPTER,
        create_new_adapter=False,
        **TRAIN_ARGS,
    )
    reference_model = load_reference_model(
        TINY_LLAMA3,
        TINY_LLAMA_ADAPTER,
        use_lora=True,
        is_trainable=True,
    )
    compare_model(train_model, reference_model)
def test_lora_train_new_adapters():
    """Compare a model with a freshly created adapter against the reference LoRA model.

    The projection module names are handed to ``compare_model`` as ``diff_keys``.
    """
    projection_keys = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
    train_model = load_train_model(
        adapter_name_or_path=TINY_LLAMA_ADAPTER,
        create_new_adapter=True,
        **TRAIN_ARGS,
    )
    reference_model = load_reference_model(
        TINY_LLAMA3,
        TINY_LLAMA_ADAPTER,
        use_lora=True,
        is_trainable=True,
    )
    compare_model(train_model, reference_model, diff_keys=projection_keys)
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
def test_lora_train_valuehead():
    """Check that the value-head summary weight and bias match the value-head reference model."""
    train_model = load_train_model(add_valuehead=True, **TRAIN_ARGS)
    reference_model = load_reference_model(TINY_LLAMA_VALUEHEAD, is_trainable=True, add_valuehead=True)
    actual = train_model.state_dict()
    expected = reference_model.state_dict()
    for key in ("v_head.summary.weight", "v_head.summary.bias"):
        assert torch.allclose(actual[key], expected[key])
def test_lora_inference():
    """Compare the inference model against the reference LoRA model after ``merge_and_unload()``."""
    inference_model = load_infer_model(**INFER_ARGS)
    lora_reference = load_reference_model(TINY_LLAMA3, TINY_LLAMA_ADAPTER, use_lora=True)
    merged_reference = lora_reference.merge_and_unload()
    compare_model(inference_model, merged_reference)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model
# Base model identifier; may be overridden through the ``TINY_LLAMA3`` environment variable.
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")

# PiSSA checkpoint identifier. It is read from its own ``TINY_LLAMA_PISSA``
# variable: ``TINY_LLAMA_ADAPTER`` names the plain LoRA adapter elsewhere in the
# test-suite, so reusing it here would let a LoRA override silently redirect
# the PiSSA tests to a non-PiSSA checkpoint.
TINY_LLAMA_PISSA = os.getenv("TINY_LLAMA_PISSA", "llamafactory/tiny-random-Llama-3-pissa")

# Keyword arguments unpacked into ``load_train_model`` by ``test_pissa_train``.
TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA3,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "pissa_init": True,
    "pissa_iter": -1,
    "dataset": "llamafactory/tiny-supervised-dataset",
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}

# Keyword arguments unpacked into ``load_infer_model`` by ``test_pissa_inference``.
# The same PiSSA identifier is used for both the model and the adapter.
INFER_ARGS = {
    "model_name_or_path": TINY_LLAMA_PISSA,
    "adapter_name_or_path": TINY_LLAMA_PISSA,
    "adapter_folder": "pissa_init",
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
}
@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")
def test_pissa_train():
    """Compare a model trained with ``TRAIN_ARGS`` (PiSSA init enabled) against the PiSSA reference model."""
    train_model = load_train_model(**TRAIN_ARGS)
    reference_model = load_reference_model(
        TINY_LLAMA_PISSA,
        TINY_LLAMA_PISSA,
        use_pissa=True,
        is_trainable=True,
    )
    compare_model(train_model, reference_model)
@pytest.mark.xfail(reason="Known connection error.")
def test_pissa_inference():
    """Compare the inference model against the PiSSA reference model after ``merge_and_unload()``."""
    inference_model = load_infer_model(**INFER_ARGS)
    merged_reference = load_reference_model(
        TINY_LLAMA_PISSA,
        TINY_LLAMA_PISSA,
        use_pissa=True,
        is_trainable=False,
    ).merge_and_unload()
    compare_model(inference_model, merged_reference)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment