Commit ca625f43 authored by shihm

uodata

parent 7164651d
@@ -56,10 +56,17 @@ TEXT_MESSAGES = [
{"role": "assistant", "content": "I am fine!"},
]
VIDEO_MESSAGES = [
{"role": "user", "content": "<video>What is in this viode?"},
{"role": "assistant", "content": "A cat."},
]
AUDIOS = [np.zeros(1600)]
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
VIDEOS = [[Image.new("RGB", (32, 32), (255, 255, 255))] * 4]
NO_IMAGES = []
NO_VIDEOS = []
@@ -145,6 +152,8 @@ def _check_plugin(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
expected_mm_inputs,
)
elif plugin.__class__.__name__ == "Qwen3VLPlugin": # only check replacement
assert plugin.process_messages(VIDEO_MESSAGES, NO_IMAGES, VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
elif plugin.__class__.__name__ != "BasePlugin": # test mm_messages
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
@@ -170,6 +179,7 @@ def _check_plugin(
)
@pytest.mark.runs_on(["cpu", "mps"])
def test_base_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA3)
base_plugin = get_mm_plugin(name="base")
@@ -177,6 +187,7 @@ def test_base_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_gemma3_plugin():
@@ -199,6 +210,7 @@ def test_gemma3_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_internvl_plugin():
image_seqlen = 256
@@ -217,6 +229,7 @@ def test_internvl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
def test_llama4_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
@@ -238,6 +251,7 @@ def test_llama4_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_plugin():
image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
@@ -251,6 +265,7 @@ def test_llava_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_next_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
@@ -264,6 +279,7 @@ def test_llava_next_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
def test_llava_next_video_plugin():
image_seqlen = 1176
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
@@ -277,6 +293,7 @@ def test_llava_next_video_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
image_seqlen = 256
@@ -296,6 +313,7 @@ def test_paligemma_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_pixtral_plugin():
image_slice_height, image_slice_width = 2, 2
@@ -318,12 +336,20 @@ def test_pixtral_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
def test_qwen2_omni_plugin():
image_seqlen, audio_seqlen = 4, 2
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
qwen2_omni_plugin = get_mm_plugin(
name="qwen2_omni",
image_token="<|IMAGE|>",
video_token="<|VIDEO|>",
audio_token="<|AUDIO|>",
vision_bos_token="<|vision_bos|>",
vision_eos_token="<|vision_eos|>",
audio_bos_token="<|audio_bos|>",
audio_eos_token="<|audio_eos|>",
)
check_inputs = {"plugin": qwen2_omni_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
@@ -341,6 +367,7 @@ def test_qwen2_omni_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen2_vl_plugin():
image_seqlen = 4
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
@@ -357,6 +384,29 @@ def test_qwen2_vl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
def test_qwen3_vl_plugin():
frame_seqlen = 1
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-30B-A3B-Instruct")
qwen3_vl_plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")
check_inputs = {"plugin": qwen3_vl_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{
key: value.replace(
"<video>", # little different with original processor for default `fps=2` in our repo
"<0.2 seconds><|vision_start|>{}<|vision_end|><1.2 seconds><|vision_start|>{}<|vision_end|>".format(
"<|video_pad|>" * frame_seqlen, "<|video_pad|>" * frame_seqlen
),
)
for key, value in message.items()
}
for message in VIDEO_MESSAGES
]
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():
image_seqlen = 256
...
@@ -89,6 +89,7 @@ def _check_template(
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool): def test_encode_oneturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast) tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
...@@ -104,6 +105,7 @@ def test_encode_oneturn(use_fast: bool): ...@@ -104,6 +105,7 @@ def test_encode_oneturn(use_fast: bool):
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str)) _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool): def test_encode_multiturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast) tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
...@@ -125,6 +127,7 @@ def test_encode_multiturn(use_fast: bool): ...@@ -125,6 +127,7 @@ def test_encode_multiturn(use_fast: bool):
) )
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False]) @pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None]) @pytest.mark.parametrize("enable_thinking", [True, False, None])
...@@ -151,6 +154,7 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi ...@@ -151,6 +154,7 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str)) _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False]) @pytest.mark.parametrize("cot_messages", [True, False])
@pytest.mark.parametrize("enable_thinking", [True, False, None]) @pytest.mark.parametrize("enable_thinking", [True, False, None])
...@@ -180,6 +184,7 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t ...@@ -180,6 +184,7 @@ def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_t
) )
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("use_fast", [True, False])
def test_jinja_template(use_fast: bool): def test_jinja_template(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast) tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, use_fast=use_fast)
...@@ -190,6 +195,7 @@ def test_jinja_template(use_fast: bool): ...@@ -190,6 +195,7 @@ def test_jinja_template(use_fast: bool):
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES) assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
@pytest.mark.runs_on(["cpu", "mps"])
def test_ollama_modelfile():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
@@ -207,12 +213,14 @@ def test_ollama_modelfile():
)
@pytest.mark.runs_on(["cpu", "mps"])
def test_get_stop_token_ids():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
@@ -226,6 +234,7 @@ def test_gemma_template(use_fast: bool):
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma2_template(use_fast: bool):
@@ -239,6 +248,7 @@ def test_gemma2_template(use_fast: bool):
_check_template("google/gemma-2-2b-it", "gemma2", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
@@ -252,6 +262,7 @@ def test_llama3_template(use_fast: bool):
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
"use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Llama 4 has no slow tokenizer."))]
)
@@ -273,6 +284,7 @@ def test_llama4_template(use_fast: bool):
pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken.")),
],
)
@pytest.mark.runs_on(["cpu", "mps"])
def test_phi4_template(use_fast: bool):
prompt_str = (
f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
@@ -284,6 +296,7 @@ def test_phi4_template(use_fast: bool):
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen2_5_template(use_fast: bool):
@@ -298,6 +311,7 @@ def test_qwen2_5_template(use_fast: bool):
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("use_fast", [True, False])
@pytest.mark.parametrize("cot_messages", [True, False]) @pytest.mark.parametrize("cot_messages", [True, False])
def test_qwen3_template(use_fast: bool, cot_messages: bool): def test_qwen3_template(use_fast: bool, cot_messages: bool):
...@@ -317,6 +331,7 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool): ...@@ -317,6 +331,7 @@ def test_qwen3_template(use_fast: bool, cot_messages: bool):
_check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages) _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
@pytest.mark.runs_on(["cpu", "mps"])
def test_parse_llama3_template():
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
template = parse_template(tokenizer)
@@ -330,6 +345,7 @@ def test_parse_llama3_template():
assert template.default_system == ""
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
@@ -342,6 +358,7 @@ def test_parse_qwen_template():
assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.xfail(not HF_TOKEN, reason="Authorization.")
def test_parse_qwen3_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
...
@@ -14,6 +14,8 @@
import os
import pytest
from llamafactory.chat import ChatModel
@@ -35,11 +37,13 @@ MESSAGES = [
EXPECTED_RESPONSE = "_rho"
@pytest.mark.runs_on(["cpu", "mps"])
def test_chat():
chat_model = ChatModel(INFER_ARGS)
assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE
@pytest.mark.runs_on(["cpu", "mps"])
def test_stream_chat():
chat_model = ChatModel(INFER_ARGS)
response = ""
...
@@ -39,6 +39,7 @@ MESSAGES = [
]
@pytest.mark.runs_on(["cuda"])
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_chat():
r"""Test the SGLang engine's basic chat functionality."""
@@ -48,6 +49,7 @@ def test_chat():
print(response.response_text)
@pytest.mark.runs_on(["cuda"])
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_stream_chat():
r"""Test the SGLang engine's streaming chat functionality."""
...
@@ -49,6 +49,7 @@ INFER_ARGS = {
OS_NAME = os.getenv("OS_NAME", "")
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize(
"stage,dataset",
[
@@ -65,6 +66,7 @@ def test_run_exp(stage: str, dataset: str):
assert os.path.exists(output_dir)
@pytest.mark.runs_on(["cpu", "mps"])
def test_export():
export_dir = os.path.join("output", "llama3_export")
export_model({"export_dir": export_dir, **INFER_ARGS})
...
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from llamafactory.eval.template import get_eval_template
@pytest.mark.runs_on(["cpu", "mps"])
def test_eval_template_en():
support_set = [
{
@@ -53,6 +56,7 @@ def test_eval_template_en():
]
@pytest.mark.runs_on(["cpu", "mps"])
def test_eval_template_zh():
support_set = [
{
...
@@ -15,7 +15,17 @@
import os
import pytest
from transformers.utils import is_flash_attn_2_available
# Compatible with Transformers v4 and Transformers v5
try:
from transformers.utils import is_torch_sdpa_available
except ImportError:
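# Newer Transformers releases drop this helper; SDPA ships with modern PyTorch, so the fallback assumes it is available.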
def is_torch_sdpa_available():
return True
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.train.test_utils import load_infer_model
...
@@ -16,7 +16,7 @@ import os
import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
@@ -30,11 +30,6 @@ INFER_ARGS = {
}
@pytest.fixture
def fix_valuehead_cpu_loading():
patch_valuehead_model()
def test_base():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA3)
...
@@ -23,7 +23,6 @@ from llamafactory.train.test_utils import (
load_infer_model,
load_reference_model,
load_train_model,
patch_valuehead_model,
)
@@ -56,11 +55,6 @@ INFER_ARGS = {
}
@pytest.fixture
def fix_valuehead_cpu_loading():
patch_valuehead_model()
def test_lora_train_qv_modules():
model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model)
...
# change if test fails or cache is outdated
0.9.4.105
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch.multiprocessing as mp
from llamafactory.v1.accelerator.helper import ReduceOp
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.utils.env import find_available_port
from llamafactory.v1.utils.pytest import dist_env
def _all_reduce_tests(local_rank: int, world_size: int, master_port: int):
with dist_env(local_rank, world_size, master_port):
rank = DistributedInterface().get_rank()
world_size = DistributedInterface().get_world_size()
assert world_size == 2
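# With two ranks (0 and 1), rank + 1.0 yields 1.0 and 2.0, so SUM=3.0, MEAN=1.5, and MAX=2.0 below.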
y_sum = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.SUM)
assert y_sum == pytest.approx(3.0)
y_mean = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MEAN)
assert y_mean == pytest.approx(1.5)
y_max = DistributedInterface().all_reduce(rank + 1.0, op=ReduceOp.MAX)
assert y_max == pytest.approx(2.0)
z = DistributedInterface().all_gather(rank + 1.0)
assert z == pytest.approx([1.0, 2.0])
z = DistributedInterface().broadcast(rank + 1.0)
assert z == pytest.approx(1.0)
def test_all_device():
assert DistributedInterface().get_rank() == int(os.getenv("RANK", "0"))
assert DistributedInterface().get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
assert DistributedInterface().get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
assert DistributedInterface().get_local_world_size() == int(os.getenv("LOCAL_WORLD_SIZE", "1"))
@pytest.mark.runs_on(["cuda", "npu"])
@pytest.mark.require_distributed(2)
def test_multi_device():
master_port = find_available_port()
mp.spawn(_all_reduce_tests, args=(2, master_port), nprocs=2)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import sys
from unittest.mock import patch
from llamafactory.v1.config.arg_parser import get_args
def test_get_args_from_yaml(tmp_path: pathlib.Path):
config_yaml = """
### model
model: "llamafactory/tiny-random-qwen2.5"
trust_remote_code: true
use_fast_processor: true
model_class: "llm"
kernel_config:
name: "auto"
include_kernels: "auto" # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
peft_config:
name: "lora"
lora_rank: 0.8
quant_config: null
### data
dataset: "llamafactory/tiny-supervised-dataset"
cutoff_len: 2048
### training
output_dir: "outputs/test_run"
micro_batch_size: 1
global_batch_size: 1
learning_rate: 1.0e-4
bf16: false
dist_config: null
### sample
sample_backend: "hf"
max_new_tokens: 128
"""
config_file = tmp_path / "config.yaml"
config_file.write_text(config_yaml, encoding="utf-8")
test_argv = ["test_args_parser.py", str(config_file)]
with patch.object(sys, "argv", test_argv):
data_args, model_args, training_args, sample_args = get_args()
assert training_args.output_dir == "outputs/test_run"
assert training_args.micro_batch_size == 1
assert training_args.global_batch_size == 1
assert training_args.learning_rate == 1.0e-4
assert training_args.bf16 is False
assert training_args.dist_config is None
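# Nested plugin configs expose their declared fields as attributes and extra keys via .get().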
assert model_args.model == "llamafactory/tiny-random-qwen2.5"
assert model_args.kernel_config.name == "auto"
assert model_args.kernel_config.get("include_kernels") == "auto"
assert model_args.peft_config.name == "lora"
assert model_args.peft_config.get("lora_rank") == 0.8
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA-Factory test configuration.
Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
import sys
import pytest
import torch
from pytest import Config, FixtureRequest, Item, MonkeyPatch
from llamafactory.v1.accelerator.helper import get_current_accelerator, get_device_count
from llamafactory.v1.utils.env import is_env_enabled
from llamafactory.v1.utils.packages import is_transformers_version_greater_than
CURRENT_DEVICE = get_current_accelerator().type
def pytest_configure(config: Config):
"""Register custom pytest markers."""
config.addinivalue_line(
"markers",
"slow: marks tests as slow (deselect with '-m \"not slow\"' or set RUN_SLOW=1 to run)",
)
config.addinivalue_line(
"markers",
"runs_on: test requires specific device type, e.g., @pytest.mark.runs_on(['cuda'])",
)
config.addinivalue_line(
"markers",
"require_distributed(num_devices): allow multi-device execution (default: 2)",
)
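# Illustrative usage of these markers (see test_multi_device in tests_v1):
#   @pytest.mark.runs_on(["cuda", "npu"])
#   @pytest.mark.require_distributed(2)
#   def test_multi_device(): ...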
def _handle_runs_on(items: list[Item]):
"""Skip tests on specified device TYPES (cpu/cuda/npu)."""
for item in items:
marker = item.get_closest_marker("runs_on")
if not marker:
continue
devices = marker.args[0]
if isinstance(devices, str):
devices = [devices]
if CURRENT_DEVICE not in devices:
item.add_marker(pytest.mark.skip(reason=f"test requires one of {devices} (current: {CURRENT_DEVICE})"))
def _handle_slow_tests(items: list[Item]):
"""Skip slow tests unless RUN_SLOW is enabled."""
if not is_env_enabled("RUN_SLOW"):
skip_slow = pytest.mark.skip(reason="slow test (set RUN_SLOW=1 to run)")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
def _get_visible_devices_env() -> str | None:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
elif CURRENT_DEVICE == "npu":
return "ASCEND_RT_VISIBLE_DEVICES"
else:
return None
def _handle_device_visibility(items: list[Item]):
"""Handle device visibility based on test markers."""
env_key = _get_visible_devices_env()
if env_key is None or CURRENT_DEVICE in ("cpu", "mps"):
return
# Parse visible devices
visible_devices_env = os.environ.get(env_key)
if visible_devices_env is None:
available = get_device_count()
else:
visible_devices = [v for v in visible_devices_env.split(",") if v != ""]
available = len(visible_devices)
for item in items:
marker = item.get_closest_marker("require_distributed")
if not marker:
continue
required = marker.args[0] if marker.args else 2
if available < required:
item.add_marker(pytest.mark.skip(reason=f"test requires {required} devices, but only {available} visible"))
def pytest_collection_modifyitems(config: Config, items: list[Item]):
"""Modify test collection based on markers and environment."""
# Handle version compatibility (from HEAD)
if not is_transformers_version_greater_than("4.57.0"):
skip_bc = pytest.mark.skip(reason="Skip backward compatibility tests")
for item in items:
if "tests_v1" in str(item.fspath):
item.add_marker(skip_bc)
_handle_slow_tests(items)
_handle_runs_on(items)
_handle_device_visibility(items)
@pytest.fixture(autouse=True)
def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -> None:
"""Set environment variables for distributed tests if specific devices are requested."""
env_key = _get_visible_devices_env()
if not env_key:
return
# Save old environment for logic checks, monkeypatch handles restoration
old_value = os.environ.get(env_key)
marker = request.node.get_closest_marker("require_distributed")
if marker: # distributed test
required = marker.args[0] if marker.args else 2
specific_devices = marker.args[1] if len(marker.args) > 1 else None
if specific_devices:
devices_str = ",".join(map(str, specific_devices))
else:
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
# add project root dir to path for mp run
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
monkeypatch.setenv(env_key, visible_devices[0] if visible_devices else "0")
else:
monkeypatch.setenv(env_key, "0")
if CURRENT_DEVICE == "cuda":
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
elif CURRENT_DEVICE == "npu":
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import pytest
from datasets import load_dataset
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
@pytest.mark.parametrize("num_samples", [16])
def test_map_dataset(num_samples: int):
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
original_data = load_dataset("llamafactory/v1-sft-demo", split="train")
indexes = random.choices(range(len(data_engine)), k=num_samples)
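# random.choices samples indexes with replacement, which is sufficient for a spot check against the original data.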
for index in indexes:
print(data_engine[index])
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
if __name__ == "__main__":
test_map_dataset(1)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.trainer_utils.data_collator import (
DefaultCollator,
)
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
from llamafactory.v1.utils.batching_queue import TextBatchingQueue
class TensorDataset(Dataset):
"""Wrapper dataset that converts DataEngine samples to tensor format."""
def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
self.data_engine = data_engine
self.processor = processor
self.template = template
self.max_samples = max_samples or len(data_engine)
self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
def __len__(self):
return min(self.max_samples, len(self.data_engine))
def __getitem__(self, idx):
# Get sample from DataEngine
sample = self.data_engine[idx]
# Extract messages from sample
# DataEngine returns samples with format like {"messages": [...], ...}
# For llamafactory/v1-sft-demo, the format should have "messages" field
messages = None
if "messages" in sample:
messages = sample["messages"]
elif "conversations" in sample:
messages = sample["conversations"]
elif "conversation" in sample:
messages = sample["conversation"]
else:
# Try to find message-like fields (skip _dataset_name)
for key, value in sample.items():
if key.startswith("_"):
continue
if isinstance(value, list) and len(value) > 0:
# Check if it looks like a message list
if isinstance(value[0], dict) and "role" in value[0]:
messages = value
break
if messages is None:
raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
# Encode messages using template
encoded = self.template.encode_messages(self.tokenizer, messages)
# Convert to tensors
return {
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
"labels": torch.tensor(encoded["labels"], dtype=torch.long),
}
def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
"""Create a real dataset using DataEngine."""
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
# Create processor and template
processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
template = QwenTemplate()
# Create tensor dataset
raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
# Create torch DataLoader
torch_dataloader = TorchDataLoader(
raw_data_dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=lambda x: x,
)
return torch_dataloader, processor, template
class TestDataLoaderNonPackNonDynamic:
"""Test case a) non pack + non dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing and without dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# Create collator (non-packing)
collator = DefaultCollator(processor=processor, template=template)
# Create DataLoader without batching_queue (non-dynamic)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=1,
batching_queue=None,
)
# Iterate and check results
batches = list(iter(data_loader))
assert len(batches) > 0
# Check first batch
one_batch = batches[0]
micro_batches = one_batch[0]
assert "input_ids" in micro_batches
assert "attention_mask" in micro_batches
assert "labels" in micro_batches
assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
class TestDataLoaderNonPackDynamic:
"""Test case b) non pack + dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing but with dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
collator = DefaultCollator(processor=processor, template=template)
# Create batching queue for dynamic batching
batching_queue = TextBatchingQueue(
token_micro_bsz=120,
buffer_size=8,
)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=4,
batching_queue=batching_queue,
)
# Iterate and check
batches = list(iter(data_loader))
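# With dynamic batching, each micro-batch should stay within token_micro_bsz=120 attended tokens.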
micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
assert len(batches) > 0
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader
def test_tiny_qwen():
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
model_loader = ModelLoader(model_args)
assert isinstance(model_loader.processor, Qwen2TokenizerFast)
assert isinstance(model_loader.model.config, Qwen2Config)
assert isinstance(model_loader.model, Qwen2ForCausalLM)
assert model_loader.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
from transformers import Qwen2ForCausalLM
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
)
model_loader = ModelLoader(model_args)
# test enable apply kernel plugin
if hasattr(torch, "npu"):
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
else:
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_loader.model, Qwen2ForCausalLM)
if __name__ == "__main__":
test_tiny_qwen()
test_tiny_qwen_with_kernel_plugin()
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator
def test_normal_batching():
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args.train_dataset)
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
model_engine = ModelEngine(model_args=model_args)
training_args = TrainingArguments(
micro_batch_size=4,
global_batch_size=8,
cutoff_len=10,
batching_workers=0,
batching_strategy="normal",
)
batch_generator = BatchGenerator(
data_engine,
model_engine.renderer,
micro_batch_size=training_args.micro_batch_size,
global_batch_size=training_args.global_batch_size,
cutoff_len=training_args.cutoff_len,
batching_workers=training_args.batching_workers,
batching_strategy=training_args.batching_strategy,
)
assert len(batch_generator) == len(data_engine) // training_args.global_batch_size
batch = next(iter(batch_generator))
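# global_batch_size=8 with micro_batch_size=4 yields 2 micro-batches per step; cutoff_len=10 fixes the sequence length.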
assert len(batch) == 2
assert batch[0]["input_ids"].shape == (4, 10)
if __name__ == "__main__":
"""
python -m tests_v1.core.utils.test_batching
"""
test_normal_batching()
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pytest
from transformers import AutoTokenizer
from llamafactory.v1.config import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.utils.rendering import Renderer
from llamafactory.v1.utils.types import Processor
def _get_input_ids(inputs: list | dict) -> list:
if not isinstance(inputs, list):
return inputs["input_ids"]
else:
return inputs
HF_MESSAGES = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is LLM?"},
{"role": "assistant", "content": "LLM stands for Large Language Model."},
]
V1_MESSAGES = [
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
]
HF_MESSAGES_WITH_TOOLS = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 6*8?"},
{
"role": "assistant",
"tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 6, "b": 8}}}],
},
{"role": "tool", "content": "48."},
{"role": "assistant", "content": "The result of 6*8 is 48."},
]
V1_MESSAGES_WITH_TOOLS = [
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "value": "What is 6*8?"}]},
{
"role": "assistant",
"content": [{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})}],
"loss_weight": 0.0,
},
{"role": "tool", "content": [{"type": "text", "value": "48."}]},
{"role": "assistant", "content": [{"type": "text", "value": "The result of 6*8 is 48."}]},
]
V1_TOOLS = [
{
"type": "function",
"function": {
"name": "multiply",
"description": "A function that multiplies two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "The first number to multiply"},
"b": {"type": "number", "description": "The second number to multiply"},
},
"required": ["a", "b"],
},
},
}
]
def test_chatml_rendering():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True))
v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
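# With is_generate=True the whole sequence is prompt, so labels are fully masked (-100) and loss weights are zero.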
assert v1_inputs["input_ids"] == hf_inputs
assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
assert v1_inputs["labels"] == [-100] * len(hf_inputs)
assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
hf_inputs_part = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False))
hf_inputs_full = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False))
v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
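# For full rendering only the final assistant turn contributes to the loss: the prompt span keeps -100 labels and zero weights.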
assert v1_inputs_full["input_ids"] == hf_inputs_full
assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
len(hf_inputs_full) - len(hf_inputs_part)
)
def test_chatml_parse():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
generated_text = "LLM stands for Large Language Model."
parsed_message = renderer.parse_message(generated_text)
assert parsed_message == V1_MESSAGES[-1]
@pytest.mark.parametrize("num_samples", [16])
def test_chatml_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
data_args = DataArguments(train_dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args.train_dataset)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
assert v1_inputs["input_ids"][: len(prefix)] == prefix
def test_qwen3_nothink_rendering():
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
hf_inputs = _get_input_ids(
tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
)
v1_inputs = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True)
assert v1_inputs["input_ids"] == hf_inputs
assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
assert v1_inputs["labels"] == [-100] * len(hf_inputs)
assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
hf_inputs_part = _get_input_ids(
tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False)
)
hf_inputs_full = _get_input_ids(
tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
)
v1_inputs_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False)
assert v1_inputs_full["input_ids"] == hf_inputs_full
assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
len(hf_inputs_full) - len(hf_inputs_part)
)
def test_qwen3_nothink_parse():
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
generated_text = (
"<thinking>I need to use the multiply function to calculate 6*8.</thinking>"
"Let me call the multiply function."
'<tool_call>{"name": "multiply", "arguments": {"a": 6, "b": 8}}</tool_call>'
)
parsed_message = renderer.parse_message(generated_text)
assert parsed_message == {
"role": "assistant",
"content": [
{"type": "reasoning", "value": "I need to use the multiply function to calculate 6*8."},
{"type": "text", "value": "Let me call the multiply function."},
{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})},
],
}
@pytest.mark.parametrize("num_samples", [8])
def test_qwen3_nothink_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
data_args = DataArguments(train_dataset="llamafactory/reason-tool-use-demo-1500")
data_engine = DataEngine(data_args.train_dataset)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
prefix_text = (
"<|im_start|>system\nYou are a methodical and expert assistant. "
"Your primary goal is to solve user requests by leveraging a set of available tools. "
"You must reason for the best course of action in a structured manner before responding.\n\n"
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>\n"
'{"type": "function", "function": {"name":'
)
prefix = tokenizer.encode(prefix_text, add_special_tokens=False)
print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
assert v1_inputs["input_ids"][: len(prefix)] == prefix
def test_process_sft_samples():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}]
model_inputs = renderer.process_samples(samples)
assert len(model_inputs) == 1
assert model_inputs[0]["input_ids"] == hf_inputs
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"
def test_process_dpo_samples():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
hf_inputs = _get_input_ids(tokenizer.apply_chat_template(HF_MESSAGES))
samples = [
{
"chosen_messages": V1_MESSAGES,
"rejected_messages": V1_MESSAGES,
"extra_info": "test",
"_dataset_name": "default",
}
]
model_inputs = renderer.process_samples(samples)
assert len(model_inputs) == 1
assert model_inputs[0]["input_ids"] == hf_inputs * 2
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"
if __name__ == "__main__":
"""
python -m tests_v1.core.utils.test_rendering
"""
test_chatml_rendering()
test_chatml_parse()
test_chatml_rendering_remote(16)
test_qwen3_nothink_rendering()
test_qwen3_nothink_parse()
test_qwen3_nothink_rendering_remote(16)
test_process_sft_samples()
test_process_dpo_samples()
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import pytest
from datasets import load_dataset
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.plugins.data_plugins.converter import DataConverterPlugin
@pytest.mark.parametrize("num_samples", [16])
def test_alpaca_converter(num_samples: int):
data_args = DataArguments(dataset="llamafactory/v1-dataset-info/tiny-supervised-dataset.yaml")
data_engine = DataEngine(data_args)
original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
indexes = random.choices(range(len(data_engine)), k=num_samples)
for index in indexes:
print(data_engine[index])
expected_data = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "value": original_data[index]["instruction"] + original_data[index]["input"]}
],
"loss_weight": 0.0,
},
{
"role": "assistant",
"content": [{"type": "text", "value": original_data[index]["output"]}],
"loss_weight": 1.0,
},
]
}
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
def test_sharegpt_converter():
example = {
"conversations": [
{"from": "system", "value": "System"},
{"from": "human", "value": "User"},
{"from": "function_call", "value": "Tool"},
{"from": "observation", "value": "Observation"},
{"from": "gpt", "value": "Assistant"},
]
}
expected_data = {
"messages": [
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
{"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"},
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
]
}
assert DataConverterPlugin("sharegpt")(example) == expected_data
@pytest.mark.parametrize("num_samples", [16])
def test_pair_converter(num_samples: int):
data_args = DataArguments(dataset="llamafactory/v1-dataset-info/orca-dpo-pairs.yaml")
data_engine = DataEngine(data_args)
original_data = load_dataset("HuggingFaceH4/orca_dpo_pairs", split="train_prefs")
indexes = random.choices(range(len(data_engine)), k=num_samples)
for index in indexes:
print(data_engine[index])
print(original_data[index])
expected_data = {
"chosen_messages": [
{
"role": "system",
"content": [{"type": "text", "value": original_data[index]["chosen"][0]["content"]}],
"loss_weight": 0.0,
},
{
"role": "user",
"content": [{"type": "text", "value": original_data[index]["chosen"][1]["content"]}],
"loss_weight": 0.0,
},
{
"role": "assistant",
"content": [{"type": "text", "value": original_data[index]["chosen"][2]["content"]}],
"loss_weight": 1.0,
},
],
"rejected_messages": [
{
"role": "system",
"content": [{"type": "text", "value": original_data[index]["rejected"][0]["content"]}],
"loss_weight": 0.0,
},
{
"role": "user",
"content": [{"type": "text", "value": original_data[index]["rejected"][1]["content"]}],
"loss_weight": 0.0,
},
{
"role": "assistant",
"content": [{"type": "text", "value": original_data[index]["rejected"][2]["content"]}],
"loss_weight": 1.0,
},
],
}
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
if __name__ == "__main__":
test_alpaca_converter(1)
test_sharegpt_converter()
test_pair_converter(1)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta():
model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_meta"},
)
)
model_engine = ModelEngine(model_args=model_args)
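# init_on_meta builds the model on the meta device, so no parameter memory is allocated yet.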
assert model_engine.model.device.type == "meta"
def test_init_on_rank0():
model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_rank0"},
)
)
model_engine = ModelEngine(model_args=model_args)
if DistributedInterface().get_rank() == 0:
assert model_engine.model.device.type == "cpu"
else:
assert model_engine.model.device.type == "meta"
def test_init_on_default():
model_args, *_ = get_args(
dict(
model="llamafactory/tiny-random-qwen3",
init_config={"name": "init_on_default"},
)
)
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device
if __name__ == "__main__":
"""
python tests_v1/plugins/model_plugins/test_init_plugin.py
"""
test_init_on_meta()
test_init_on_rank0()
test_init_on_default()