Commit c7d1b209 authored by chenych

Update 0429

parent c8d12c06
@@ -15,11 +15,13 @@
import os
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest
import torch
from PIL import Image

from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -42,11 +44,20 @@ MM_MESSAGES = [
    {"role": "assistant", "content": "A cat."},
]

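# Interleaved image and audio turns used to exercise the omni (audio + vision) plugin.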
OMNI_MESSAGES = [
{"role": "user", "content": "<image>What is in this image?"},
{"role": "assistant", "content": "A cat."},
{"role": "user", "content": "<audio>What is in this audio?"},
{"role": "assistant", "content": "Nothing."},
]

TEXT_MESSAGES = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
]

AUDIOS = [np.zeros(1600)]
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
NO_IMAGES = []

@@ -57,6 +68,8 @@ NO_AUDIOS = []
IMGLENS = [1]
AUDLENS = [1]
NO_IMGLENS = [0]
NO_VIDLENS = [0]

@@ -75,6 +88,25 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
    return image_processor(images=IMAGES, return_tensors="pt")

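# Helper that mirrors how the omni processor combines image features with padded audio features.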
def _get_omni_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
    mm_inputs = {}
    image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
    feature_extractor = getattr(processor, "feature_extractor", None)
    mm_inputs.update(image_processor(IMAGES, return_tensors="pt"))
    mm_inputs.update(
        feature_extractor(
            AUDIOS,
            sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
            return_attention_mask=True,
            padding="max_length",
            return_tensors="pt",
        )
    )
    mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")
    return mm_inputs


def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
    assert batch_a.keys() == batch_b.keys()
    for key in batch_a.keys():

@@ -103,6 +135,17 @@ def _check_plugin(
    expected_mm_inputs: dict[str, Any] = {},
    expected_no_mm_inputs: dict[str, Any] = {},
) -> None:
    # test omni_messages
    if plugin.__class__.__name__ == "Qwen2OmniPlugin":
        assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
            expected_input_ids,
            expected_labels,
        )
        _is_close(
            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
            expected_mm_inputs,
        )

    # test mm_messages
    if plugin.__class__.__name__ != "BasePlugin":
        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages

@@ -137,6 +180,7 @@ def test_base_plugin():
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_gemma3_plugin():
    image_seqlen = 256
    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")

@@ -157,6 +201,24 @@ def test_gemma3_plugin():
    _check_plugin(**check_inputs)

@pytest.mark.xfail(reason="Unknown error.")
def test_internvl_plugin():
    image_seqlen = 256
    tokenizer_module = _load_tokenizer_module(model_name_or_path="OpenGVLab/InternVL3-1B-hf")
    internvl_plugin = get_mm_plugin("intern_vl", image_token="<image>", video_token="<video>")
    check_inputs = {"plugin": internvl_plugin, **tokenizer_module}
    check_inputs["expected_mm_messages"] = [
        {
            key: value.replace("<image>", f"<img>{'<IMG_CONTEXT>' * image_seqlen * 1}</img>")
            for key, value in message.items()
        }
        for message in MM_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
    check_inputs["expected_mm_inputs"].pop("num_patches", None)
    _check_plugin(**check_inputs)


@pytest.mark.xfail(reason="Unknown error.")
def test_llama4_plugin():
    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)

@@ -178,6 +240,7 @@ def test_llama4_plugin():
    _check_plugin(**check_inputs)

@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_llava_plugin():
    image_seqlen = 576
    tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")

@@ -236,6 +299,7 @@ def test_paligemma_plugin():
    _check_plugin(**check_inputs)

@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
def test_pixtral_plugin():
    image_slice_height, image_slice_width = 2, 2
    tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")

@@ -257,6 +321,30 @@ def test_pixtral_plugin():
    _check_plugin(**check_inputs)

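# Qwen2.5-Omni is expected to wrap image placeholders in <|vision_bos|>...<|vision_eos|> and audio placeholders in <|audio_bos|>...<|audio_eos|>.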
@pytest.mark.xfail(reason="Unknown error.")
def test_qwen2_omni_plugin():
    image_seqlen = 4
    audio_seqlen = 2
    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
    qwen2_omni_plugin = get_mm_plugin(
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
    )
    check_inputs = {"plugin": qwen2_omni_plugin, **tokenizer_module}
    check_inputs["expected_mm_messages"] = [
        {
            key: (
                value.replace("<image>", f"<|vision_bos|>{'<|IMAGE|>' * image_seqlen}<|vision_eos|>").replace(
                    "<audio>", f"<|audio_bos|>{'<|AUDIO|>' * audio_seqlen}<|audio_eos|>"
                )
            )
            for key, value in message.items()
        }
        for message in OMNI_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_omni_inputs(tokenizer_module["processor"])
    _check_plugin(**check_inputs)


def test_qwen2_vl_plugin():
    image_seqlen = 4
    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")

@@ -273,6 +361,7 @@ def test_qwen2_vl_plugin():
    _check_plugin(**check_inputs)

@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():
    image_seqlen = 256
    tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
......

@@ -39,6 +39,13 @@ MESSAGES = [
    {"role": "assistant", "content": "很高兴认识你!"},
]

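# Dialogue variant where each assistant turn carries a <think> reasoning block (exercised by the Qwen3 template test).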
MESSAGES_WITH_THOUGHT = [
{"role": "user", "content": "How are you"},
{"role": "assistant", "content": "<think>\nModel thought here\n</think>\n\nI am fine!"},
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "<think>\n模型思考内容\n</think>\n\n很高兴认识你!"},
]


def _check_tokenization(
    tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]

@@ -53,7 +60,14 @@ def _check_tokenization(
        assert tokenizer.decode(input_ids) == text


def _check_template(
    model_id: str,
    template_name: str,
    prompt_str: str,
    answer_str: str,
    use_fast: bool,
    messages: list[dict[str, str]] = MESSAGES,
) -> None:
r"""Check template. r"""Check template.
Args: Args:
...@@ -62,13 +76,14 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s ...@@ -62,13 +76,14 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
prompt_str: the string corresponding to the prompt part. prompt_str: the string corresponding to the prompt part.
answer_str: the string corresponding to the answer part. answer_str: the string corresponding to the answer part.
use_fast: whether to use fast tokenizer. use_fast: whether to use fast tokenizer.
messages: the list of messages.
""" """
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN) tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False) content_str = tokenizer.apply_chat_template(messages, tokenize=False)
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True) content_ids = tokenizer.apply_chat_template(messages, tokenize=True)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name)) template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES) prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
assert content_str == prompt_str + answer_str assert content_str == prompt_str + answer_str
assert content_ids == prompt_ids + answer_ids assert content_ids == prompt_ids + answer_ids
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str)) _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))

@@ -198,7 +213,7 @@ def test_phi4_template(use_fast: bool):
@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen2_5_template(use_fast: bool):
    prompt_str = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\nHow are you<|im_end|>\n"

@@ -210,6 +225,18 @@ def test_qwen_template(use_fast: bool):
    _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)

@pytest.mark.parametrize("use_fast", [True, False])
def test_qwen3_template(use_fast: bool):
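    # Earlier assistant turns appear without their <think> blocks in the prompt; only the final answer keeps one.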
    prompt_str = (
        "<|im_start|>user\nHow are you<|im_end|>\n"
        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
        "<|im_start|>user\n你好<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)


def test_parse_llama3_template():
    tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA3, token=HF_TOKEN)
    template = parse_template(tokenizer)

@@ -231,3 +258,13 @@ def test_parse_qwen_template():
    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
    assert template.format_prefix.slots == []
    assert template.default_system == "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."

def test_parse_qwen3_template():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
    template = parse_template(tokenizer)
    assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
    assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
    assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
    assert template.format_prefix.slots == []
    assert template.default_system == ""

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer
TINY_LLAMA3 = os.getenv("TINY_LLAMA3", "llamafactory/tiny-random-Llama-3")
UNUSED_TOKEN = "<|UNUSED_TOKEN|>"
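

# A special token should disappear when decoding with skip_special_tokens, while a regular added token is kept.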
@pytest.mark.parametrize("special_tokens", [False, True])
def test_add_tokens(special_tokens: bool):
    if special_tokens:
        model_args = ModelArguments(model_name_or_path=TINY_LLAMA3, add_special_tokens=UNUSED_TOKEN)
    else:
        model_args = ModelArguments(model_name_or_path=TINY_LLAMA3, add_tokens=UNUSED_TOKEN)
    tokenizer = load_tokenizer(model_args)["tokenizer"]
    encoded_ids = tokenizer.encode(UNUSED_TOKEN, add_special_tokens=False)
    assert len(encoded_ids) == 1
    decoded_str = tokenizer.decode(encoded_ids, skip_special_tokens=True)
    if special_tokens:
        assert decoded_str == ""
    else:
        assert decoded_str == UNUSED_TOKEN


if __name__ == "__main__":
    pytest.main([__file__])

@@ -50,6 +50,10 @@ class DataCollatorWithVerbose(DataCollatorWithPadding):
    verbose_list: list[dict[str, Any]] = field(default_factory=list)

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
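        # keep only the tensor fields that DataCollatorWithPadding knows how to pad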
        features = [
            {k: v for k, v in feature.items() if k in ["input_ids", "attention_mask", "labels"]}
            for feature in features
        ]
        self.verbose_list.extend(features)
        batch = super().__call__(features)
        return {k: v[:, :1] for k, v in batch.items()}  # truncate input length
......
# change if test fails or cache is outdated
0.9.3.106