Commit 27a7ad86 authored by luopl

update to v0.9.1

parent 731cf9b8
......@@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, List, Optional
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import get_logits_processor
from ...extras.ploting import plot_loss
......@@ -43,25 +43,27 @@ def run_sft(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
data_collator = SFTDataCollatorWith4DAttentionMask(
tokenizer=tokenizer,
template=template,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
attn_implementation=getattr(model.config, "_attn_implementation", None),
compute_dtype=model_args.compute_dtype,
**tokenizer_module,
)
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
training_args.remove_unused_columns = False # important for multimodal dataset
# Metric utils
metric_module = {}
......
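To make the new call order above easier to follow, here is a minimal sketch (not the repository's full `run_sft`) of how v0.9.1 resolves the chat template once and passes it into `get_dataset`; the wrapper name `build_sft_dataset` is hypothetical.

```python
# Hedged sketch of the v0.9.1 call order shown in the diff; build_sft_dataset is a hypothetical wrapper.
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.model import load_tokenizer


def build_sft_dataset(model_args, data_args, training_args):
    tokenizer_module = load_tokenizer(model_args)  # {"tokenizer": ..., "processor": ...}
    tokenizer = tokenizer_module["tokenizer"]

    # The template is now resolved from DataArguments up front and passed explicitly.
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(
        template, model_args, data_args, training_args, stage="sft", **tokenizer_module
    )
    return template, dataset_module["train_dataset"]
```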
......@@ -19,7 +19,7 @@ from peft import PeftModel
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from ..data import get_dataset
from ..data import get_dataset, get_template_and_fix_tokenizer
from ..extras.misc import get_current_device
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
......@@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys():
if any(key in name for key in diff_keys):
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False
else:
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
......@@ -105,7 +105,8 @@ def load_reference_model(
def load_train_dataset(**kwargs) -> "Dataset":
model_args, data_args, training_args, _, _ = get_train_args(kwargs)
tokenizer_module = load_tokenizer(model_args)
dataset_module = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, kwargs["stage"], **tokenizer_module)
return dataset_module["train_dataset"]
......
......@@ -26,6 +26,7 @@ from transformers.modeling_utils import is_fsdp_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
......@@ -60,9 +61,11 @@ class DummyOptimizer(torch.optim.Optimizer):
self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": lr})
@override
def zero_grad(self, set_to_none: bool = True) -> None:
pass
@override
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
pass
......
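The `@override` decorator imported from `typing_extensions` above is purely a static-analysis aid. A minimal, self-contained illustration of what it buys (the classes below are illustrative, not the repository's):

```python
# Minimal illustration of typing_extensions.override (also in typing as of Python 3.12);
# the classes here are illustrative only.
from typing_extensions import override


class Base:
    def zero_grad(self, set_to_none: bool = True) -> None:
        ...


class Child(Base):
    @override  # a type checker (e.g. mypy, pyright) reports an error if Base stops defining zero_grad
    def zero_grad(self, set_to_none: bool = True) -> None:
        pass  # intentionally a no-op, mirroring DummyOptimizer above
```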
......@@ -72,7 +72,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
processor = tokenizer_module["processor"]
get_template_and_fix_tokenizer(tokenizer, data_args.template)
get_template_and_fix_tokenizer(tokenizer, data_args)
model = load_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab
if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
......@@ -132,12 +132,12 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
if model_args.export_hub_model_id is not None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
if model_args.visual_inputs and processor is not None:
if processor is not None:
getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None:
getattr(processor, "image_processor").push_to_hub(
model_args.export_hub_model_id, token=model_args.hf_hub_token
)
except Exception:
logger.warning("Cannot save tokenizer, please copy the files manually.")
except Exception as e:
logger.warning("Cannot save tokenizer, please copy the files manually: {}.".format(e))
......@@ -14,9 +14,7 @@
import json
import os
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
from numpy.typing import NDArray
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
from ..chat import ChatModel
from ..data import Role
......@@ -90,7 +88,6 @@ class WebChatModel(ChatModel):
template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),
......@@ -135,7 +132,8 @@ class WebChatModel(ChatModel):
messages: Sequence[Dict[str, str]],
system: str,
tools: str,
image: Optional[NDArray],
image: Optional[Any],
video: Optional[Any],
max_new_tokens: int,
top_p: float,
temperature: float,
......@@ -143,7 +141,7 @@ class WebChatModel(ChatModel):
chatbot[-1][1] = ""
response = ""
for new_text in self.stream_chat(
messages, system, tools, image, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
if tools:
......
......@@ -115,39 +115,29 @@ def get_model_path(model_name: str) -> str:
return model_path
def get_prefix(model_name: str) -> str:
r"""
Gets the prefix of the model name to obtain the model family.
"""
return model_name.split("-")[0]
def get_model_info(model_name: str) -> Tuple[str, str, bool]:
def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets the necessary information of this model.
Returns:
model_path (str)
template (str)
visual (bool)
"""
return get_model_path(model_name), get_template(model_name), get_visual(model_name)
return get_model_path(model_name), get_template(model_name)
def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat model.
"""
if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[get_prefix(model_name)]
return "default"
return DEFAULT_TEMPLATE.get(model_name, "default")
def get_visual(model_name: str) -> bool:
r"""
Judges if the model is a vision language model.
"""
return get_prefix(model_name) in VISION_MODELS
return model_name in VISION_MODELS
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
......
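The net effect of the change above: `DEFAULT_TEMPLATE` and `VISION_MODELS` are now keyed by the full model name rather than its "-"-separated prefix, so the old `get_prefix` helper is dropped. A small, self-contained sketch of the new lookups (the registry entries below are illustrative; the real ones are imported elsewhere in common.py):

```python
# Hedged sketch of the v0.9.1 lookups; the registry contents below are illustrative only.
DEFAULT_TEMPLATE = {"Llama-3-8B-Instruct": "llama3"}
VISION_MODELS = {"LLaVA-1.5-7B-Chat"}


def get_template(model_name: str) -> str:
    """Template name for chat models, falling back to "default"."""
    return DEFAULT_TEMPLATE.get(model_name, "default")


def get_visual(model_name: str) -> bool:
    """Whether the selected model is a vision-language model."""
    return model_name in VISION_MODELS


assert get_template("unknown-model") == "default"
assert get_visual("LLaVA-1.5-7B-Chat") is True
```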
......@@ -43,8 +43,12 @@ def create_chat_box(
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=3)
with gr.Column() as image_box:
image = gr.Image(sources=["upload"], type="numpy")
with gr.Column() as mm_box:
with gr.Tab("Image"):
image = gr.Image(sources=["upload"], type="pil")
with gr.Tab("Video"):
video = gr.Video(sources=["upload"])
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
......@@ -63,7 +67,7 @@ def create_chat_box(
[chatbot, messages, query],
).then(
engine.chatter.stream,
[chatbot, messages, system, tools, image, max_new_tokens, top_p, temperature],
[chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
......@@ -76,8 +80,9 @@ def create_chat_box(
role=role,
system=system,
tools=tools,
image_box=image_box,
mm_box=mm_box,
image=image,
video=video,
query=query,
submit_btn=submit_btn,
max_new_tokens=max_new_tokens,
......
......@@ -46,7 +46,6 @@ def save_model(
finetuning_type: str,
checkpoint_path: Union[str, List[str]],
template: str,
visual_inputs: bool,
export_size: int,
export_quantization_bit: str,
export_quantization_dataset: str,
......@@ -78,7 +77,6 @@ def save_model(
model_name_or_path=model_path,
finetuning_type=finetuning_type,
template=template,
visual_inputs=visual_inputs,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id or None,
export_size=export_size,
......@@ -129,7 +127,6 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.checkpoint_path"),
engine.manager.get_elem_by_id("top.template"),
engine.manager.get_elem_by_id("top.visual_inputs"),
export_size,
export_quantization_bit,
export_quantization_dataset,
......
......@@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from ..common import get_visual
from .chatbot import create_chat_box
......@@ -64,10 +65,10 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
lambda: ([], []), outputs=[chatbot, messages]
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
engine.manager.get_elem_by_id("top.visual_inputs").change(
lambda enabled: gr.Column(visible=enabled),
[engine.manager.get_elem_by_id("top.visual_inputs")],
[chat_elems["image_box"]],
engine.manager.get_elem_by_id("top.model_name").change(
lambda model_name: gr.Column(visible=get_visual(model_name)),
[engine.manager.get_elem_by_id("top.model_name")],
[chat_elems["mm_box"]],
)
return elem_dict
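Taken together with the chat-box changes earlier in this commit, the wiring above replaces the manual "visual inputs" checkbox: selecting a model toggles the image/video box automatically. A hedged, self-contained toy of that behavior (assuming Gradio 4.x; component choices mirror the diff, the model list is illustrative):

```python
# Hedged toy: picking a vision model reveals the image/video box, as the new infer tab does.
import gradio as gr

VISION_MODELS = {"LLaVA-1.5-7B-Chat"}  # illustrative; the real set lives in the repository

with gr.Blocks() as demo:
    model_name = gr.Dropdown(choices=["Llama-3-8B-Instruct", "LLaVA-1.5-7B-Chat"], label="Model name")

    with gr.Column(visible=False) as mm_box:  # hidden by default, like init_dict["infer.mm_box"]
        with gr.Tab("Image"):
            image = gr.Image(sources=["upload"], type="pil")
        with gr.Tab("Video"):
            video = gr.Video(sources=["upload"])

    model_name.change(
        lambda name: gr.Column(visible=name in VISION_MODELS),  # returns a visibility update
        [model_name],
        [mm_box],
    )

# demo.launch()  # uncomment to try it locally
```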
......@@ -43,14 +43,13 @@ def create_top() -> Dict[str, "Component"]:
with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=1)
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1)
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2)
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2)
visual_inputs = gr.Checkbox(scale=1)
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
)
model_name.input(save_config, inputs=[lang, model_name], queue=False)
......@@ -73,5 +72,4 @@ def create_top() -> Dict[str, "Component"]:
template=template,
rope_scaling=rope_scaling,
booster=booster,
visual_inputs=visual_inputs,
)
......@@ -59,7 +59,7 @@ class Engine:
init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
init_dict["infer.image_box"] = {"visible": False}
init_dict["infer.mm_box"] = {"visible": False}
if user_config.get("last_model", None):
init_dict["top.model_name"] = {"value": user_config["last_model"]}
......
......@@ -148,7 +148,7 @@ LOCALES = {
},
"zh": {
"label": "提示模板",
"info": "构建提示词时使用的模板",
"info": "构建提示词时使用的模板",
},
"ko": {
"label": "프롬프트 템플릿",
......@@ -183,20 +183,6 @@ LOCALES = {
"label": "부스터",
},
},
"visual_inputs": {
"en": {
"label": "Visual inputs",
},
"ru": {
"label": "визуальные входы",
},
"zh": {
"label": "图像输入",
},
"ko": {
"label": "시각적 입력",
},
},
"training_stage": {
"en": {
"label": "Stage",
......@@ -1705,6 +1691,20 @@ LOCALES = {
"label": "이미지 (선택 사항)",
},
},
"video": {
"en": {
"label": "Video (optional)",
},
"ru": {
"label": "Видео (по желанию)",
},
"zh": {
"label": "视频(非必填)",
},
"ko": {
"label": "비디오 (선택 사항)",
},
},
"query": {
"en": {
"placeholder": "Input...",
......
......@@ -75,5 +75,4 @@ class Manager:
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
self._id_to_elem["top.visual_inputs"],
}
......@@ -115,7 +115,7 @@ class Runner:
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
dataset_dir=get("train.dataset_dir"),
dataset=",".join(get("train.dataset")),
cutoff_len=get("train.cutoff_len"),
......@@ -251,7 +251,6 @@ class Runner:
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("eval.dataset_dir"),
eval_dataset=",".join(get("eval.dataset")),
cutoff_len=get("eval.cutoff_len"),
......
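A short hedged sketch of how the WebUI booster choice now maps onto run arguments (the added "liger_kernel" option sets `enable_liger_kernel`; the helper function is hypothetical):

```python
# Hedged sketch; booster values and argument names are taken from the diff, the helper is hypothetical.
def booster_to_args(booster: str) -> dict:
    return dict(
        flash_attn="fa2" if booster == "flashattn2" else "auto",
        use_unsloth=(booster == "unsloth"),
        enable_liger_kernel=(booster == "liger_kernel"),  # new in v0.9.1
    )


assert booster_to_args("liger_kernel") == {"flash_attn": "auto", "use_unsloth": False, "enable_liger_kernel": True}
```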
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
import pytest
import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from llamafactory.data.mm_plugin import BasePlugin
HF_TOKEN = os.environ.get("HF_TOKEN", None)
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
MM_MESSAGES = [
{"role": "user", "content": "<image>What is in this image?"},
{"role": "assistant", "content": "A cat."},
]
TEXT_MESSAGES = [
{"role": "user", "content": "How are you"},
{"role": "assistant", "content": "I am fine!"},
]
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
NO_IMAGES = []
NO_VIDEOS = []
IMGLENS = [1]
NO_IMGLENS = [0]
NO_VIDLENS = [0]
INPUT_IDS = [0, 1, 2, 3, 4]
LABELS = [0, 1, 2, 3, 4]
SEQLENS = [1024]
def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
return image_processor(images=IMAGES, return_tensors="pt")
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
assert batch_a.keys() == batch_b.keys()
for key in batch_a.keys():
if isinstance(batch_a[key], torch.Tensor):
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
else:
assert batch_a[key] == batch_b[key]
def _load_tokenizer_module(model_name_or_path: str) -> Tuple["PreTrainedTokenizer", "ProcessorMixin"]:
model_args = ModelArguments(model_name_or_path=model_name_or_path)
tokenizer_module = load_tokenizer(model_args)
return tokenizer_module["tokenizer"], tokenizer_module["processor"]
def _check_plugin(
plugin: "BasePlugin",
tokenizer: "PreTrainedTokenizer",
processor: "ProcessorMixin",
expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
expected_input_ids: List[int] = INPUT_IDS,
expected_labels: List[int] = LABELS,
expected_mm_inputs: Dict[str, Any] = {},
expected_no_mm_inputs: Dict[str, Any] = {},
) -> None:
# test mm_messages
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor) == (
expected_input_ids,
expected_labels,
)
_is_close(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, SEQLENS, processor),
expected_mm_inputs,
)
# test text_messages
assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, processor) == TEXT_MESSAGES
assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, tokenizer, processor) == (
INPUT_IDS,
LABELS,
)
_is_close(
plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, SEQLENS, processor),
expected_no_mm_inputs,
)
def test_base_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
base_plugin = get_mm_plugin(name="base", image_token="<image>")
check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
_check_plugin(**check_inputs)
def test_llava_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
llava_plugin = get_mm_plugin(name="llava", image_token="<image>")
image_seqlen = 576
check_inputs = {"plugin": llava_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
def test_llava_next_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf")
llava_next_plugin = get_mm_plugin(name="llava_next", image_token="<image>")
check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor}
image_seqlen = 1176
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
def test_llava_next_video_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf")
llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": llava_next_video_plugin, "tokenizer": tokenizer, "processor": processor}
image_seqlen = 1176
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_paligemma_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="google/paligemma-3b-pt-224")
paligemma_plugin = get_mm_plugin(name="paligemma", image_token="<image>")
image_seqlen = 256
check_inputs = {"plugin": paligemma_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "") for key, value in message.items()} for message in MM_MESSAGES
]
check_inputs["expected_input_ids"] = [tokenizer.convert_tokens_to_ids("<image>")] * image_seqlen + INPUT_IDS
check_inputs["expected_labels"] = [-100] * image_seqlen + LABELS
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * image_seqlen + [1] * (1024 - image_seqlen)]
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
_check_plugin(**check_inputs)
def test_qwen2_vl_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
image_seqlen = 4
check_inputs = {"plugin": qwen2_vl_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [
{
key: value.replace("<image>", "<|vision_start|>{}<|vision_end|>".format("<|image_pad|>" * image_seqlen))
for key, value in message.items()
}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
def test_video_llava_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
video_llava_plugin = get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>")
check_inputs = {"plugin": video_llava_plugin, "tokenizer": tokenizer, "processor": processor}
image_seqlen = 256
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
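For readers unfamiliar with the plugin API these tests exercise, a hedged usage sketch (requires downloading the LLaVA processor; the expansion count of 576 is the value asserted in `test_llava_plugin` above):

```python
# Hedged usage sketch of the mm_plugin API; calls mirror the tests above.
from PIL import Image

from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.hparams import ModelArguments
from llamafactory.model import load_tokenizer

tokenizer_module = load_tokenizer(ModelArguments(model_name_or_path="llava-hf/llava-1.5-7b-hf"))
processor = tokenizer_module["processor"]

plugin = get_mm_plugin(name="llava", image_token="<image>")
messages = [{"role": "user", "content": "<image>What is in this image?"}]
images = [Image.new("RGB", (32, 32), (255, 255, 255))]

# Each "<image>" placeholder is replaced by image_seqlen (576 for LLaVA-1.5) copies of the image token.
expanded = plugin.process_messages(messages, images, [], processor)

# get_mm_inputs runs the image processor and returns tensors (e.g. pixel_values) for the model.
mm_inputs = plugin.get_mm_inputs(images, [], [1], [0], [1024], processor)
```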
......@@ -19,6 +19,8 @@ import pytest
from transformers import AutoTokenizer
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.template import _get_jinja_template
from llamafactory.hparams import DataArguments
if TYPE_CHECKING:
......@@ -51,7 +53,7 @@ def _check_single_template(
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
assert content_str == prompt_str + answer_str + extra_str
assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
......@@ -78,7 +80,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_oneturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
prompt_str = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
......@@ -93,7 +95,7 @@ def test_encode_oneturn(use_fast: bool):
@pytest.mark.parametrize("use_fast", [True, False])
def test_encode_multiturn(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
prompt_str_1 = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
......@@ -116,7 +118,8 @@ def test_encode_multiturn(use_fast: bool):
def test_jinja_template(use_fast: bool):
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
get_template_and_fix_tokenizer(tokenizer, name="llama3")
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
tokenizer.chat_template = _get_jinja_template(template, tokenizer) # llama3 template no replace
assert tokenizer.chat_template != ref_tokenizer.chat_template
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
......@@ -157,7 +160,7 @@ def test_qwen_template():
_check_template("Qwen/Qwen2-7B-Instruct", "qwen", prompt_str, answer_str, extra_str="\n")
@pytest.mark.skip(reason="The fast tokenizer of Yi model is corrupted.")
@pytest.mark.xfail(reason="The fast tokenizer of Yi model is corrupted.")
def test_yi_template():
prompt_str = (
"<|im_start|>user\nHow are you<|im_end|>\n"
......
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from llamafactory.chat import ChatModel
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"finetuning_type": "lora",
"template": "llama3",
"infer_dtype": "float16",
"do_sample": False,
"max_new_tokens": 1,
}
MESSAGES = [
{"role": "user", "content": "Hi"},
]
EXPECTED_RESPONSE = "_rho"
def test_chat():
chat_model = ChatModel(INFER_ARGS)
assert chat_model.chat(MESSAGES)[0].response_text == EXPECTED_RESPONSE
def test_stream_chat():
chat_model = ChatModel(INFER_ARGS)
response = ""
for token in chat_model.stream_chat(MESSAGES):
response += token
assert response == EXPECTED_RESPONSE
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from llamafactory.train.tuner import export_model, run_exp
DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
"do_train": True,
"finetuning_type": "lora",
"dataset_dir": "REMOTE:" + DEMO_DATA,
"template": "llama3",
"cutoff_len": 1,
"overwrite_cache": False,
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"max_steps": 1,
}
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,
"adapter_name_or_path": TINY_LLAMA_ADAPTER,
"finetuning_type": "lora",
"template": "llama3",
"infer_dtype": "float16",
"export_dir": "llama3_export",
}
OS_NAME = os.environ.get("OS_NAME", "")
@pytest.mark.parametrize(
"stage,dataset",
[
("pt", "c4_demo"),
("sft", "alpaca_en_demo"),
("dpo", "dpo_en_demo"),
("kto", "kto_en_demo"),
pytest.param("rm", "dpo_en_demo", marks=pytest.mark.xfail(OS_NAME.startswith("windows"), reason="OS error.")),
],
)
def test_run_exp(stage: str, dataset: str):
output_dir = "train_{}".format(stage)
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
assert os.path.exists(output_dir)
def test_export():
export_model(INFER_ARGS)
assert os.path.exists("llama3_export")
......@@ -51,6 +51,12 @@ def test_checkpointing_disable():
assert getattr(module, "gradient_checkpointing") is False
def test_unsloth_gradient_checkpointing():
model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing" # classmethod
def test_upcast_layernorm():
model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
for name, param in model.named_parameters():
......
......@@ -14,6 +14,8 @@
import os
import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model
......@@ -47,13 +49,17 @@ INFER_ARGS = {
"infer_dtype": "float16",
}
OS_NAME = os.environ.get("OS_NAME", "")
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
def test_pissa_train():
model = load_train_model(**TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
compare_model(model, ref_model)
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
def test_pissa_inference():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)
......