Commit 7ea81099 authored by chenych's avatar chenych

update llama4

parent 84987715
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
......@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, "Component"]:
def create_top() -> dict[str, "Component"]:
with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
......
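The pattern repeated throughout this commit replaces the typing.Dict, typing.List and typing.Tuple aliases with the built-in generics that PEP 585 made subscriptable in Python 3.9, and imports Generator from collections.abc; quoted annotations such as "gr.Dropdown" are likewise unquoted where the class is available at runtime. A minimal before/after sketch of the style (the function below is illustrative, not part of the diff):

# Before: typing aliases, soft-deprecated since Python 3.9 (PEP 585).
from typing import Dict, List, Tuple

def average_metric_old(metrics: Dict[str, List[float]]) -> Tuple[str, float]:
    name, values = next(iter(metrics.items()))
    return name, sum(values) / len(values)

# After: built-in generics, as used throughout this commit.
def average_metric_new(metrics: dict[str, list[float]]) -> tuple[str, float]:
    name, values = next(iter(metrics.items()))
    return name, sum(values) / len(values)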
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from transformers.trainer_utils import SchedulerType
......@@ -34,7 +34,7 @@ if TYPE_CHECKING:
from ..engine import Engine
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
......@@ -382,8 +382,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
lang = engine.manager.get_elem_by_id("top.lang")
model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name")
finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type")
model_name: gr.Dropdown = engine.manager.get_elem_by_id("top.model_name")
finetuning_type: gr.Dropdown = engine.manager.get_elem_by_id("top.finetuning_type")
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
......
......@@ -14,7 +14,7 @@
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from transformers.trainer_utils import get_last_checkpoint
......@@ -39,8 +39,7 @@ if is_gradio_available():
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
r"""
Judges if the quantization is available in this finetuning type.
r"""Judge if the quantization is available in this finetuning type.
Inputs: top.finetuning_type
Outputs: top.quantization_bit
......@@ -52,8 +51,7 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
r"""
Gets the available quantization bits.
r"""Get the available quantization bits.
Inputs: top.quantization_method
Outputs: top.quantization_bit
......@@ -68,9 +66,8 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
return gr.Dropdown(choices=available_bits)
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
Modifys states after changing the training stage.
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> tuple[list[str], bool]:
r"""Modify states after changing the training stage.
Inputs: train.training_stage
Outputs: train.dataset, train.packing
......@@ -78,9 +75,8 @@ def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple
return [], TRAINING_STAGES[training_stage] == "pt"
def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets the necessary information of this model.
def get_model_info(model_name: str) -> tuple[str, str]:
r"""Get the necessary information of this model.
Inputs: top.model_name
Outputs: top.model_path, top.template
......@@ -88,9 +84,8 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
return get_model_path(model_name), get_template(model_name)
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
r"""
Gets training infomation for monitor.
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
r"""Get training infomation for monitor.
If do_train is True:
Inputs: top.lang, train.output_path
......@@ -110,7 +105,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
trainer_log: list[dict[str, Any]] = []
with open(trainer_log_path, encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
......@@ -143,8 +138,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
r"""
Lists all available checkpoints.
r"""List all available checkpoints.
Inputs: top.model_name, top.finetuning_type
Outputs: top.checkpoint_path
......@@ -166,8 +160,7 @@ def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
r"""List all the saved configuration files.
Inputs: train.current_time
Outputs: train.config_path
......@@ -182,8 +175,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
r"""
Lists all available datasets in the dataset dir for the training stage.
r"""List all available datasets in the dataset dir for the training stage.
Inputs: *.dataset_dir, *.training_stage
Outputs: *.dataset
......@@ -195,8 +187,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""
Lists all the directories that can resume from.
r"""List all the directories that can resume from.
Inputs: top.model_name, top.finetuning_type, train.current_time
Outputs: train.output_dir
......
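The docstring edits in this file all follow one convention: a single imperative summary line placed on the same line as the opening quotes (PEP 257, matching pydocstyle rules D200/D212/D401), replacing the older two-line third-person form. A small illustration with a hypothetical helper, not taken from the diff:

# Old style removed by this commit: summary on its own line, third person.
def describe_model_old(model_name: str) -> str:
    r"""
    Gets the description of this model.
    """
    return model_name

# New style adopted by this commit: one line, imperative mood.
def describe_model_new(model_name: str) -> str:
    r"""Get the description of this model."""
    return model_name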
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any
from .chatter import WebChatModel
from .common import create_ds_config, get_time, load_config
......@@ -26,9 +26,7 @@ if TYPE_CHECKING:
class Engine:
r"""
A general engine to control the behaviors of Web UI.
"""
r"""A general engine to control the behaviors of Web UI."""
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode
......@@ -39,11 +37,9 @@ class Engine:
if not demo_mode:
create_ds_config()
def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
r"""
Updates gradio components according to the (elem_id, properties) mapping.
"""
output_dict: Dict["Component", "Component"] = {}
def _update_component(self, input_dict: dict[str, dict[str, Any]]) -> dict["Component", "Component"]:
r"""Update gradio components according to the (elem_id, properties) mapping."""
output_dict: dict[Component, Component] = {}
for elem_id, elem_attr in input_dict.items():
elem = self.manager.get_elem_by_id(elem_id)
output_dict[elem] = elem.__class__(**elem_attr)
......@@ -51,9 +47,7 @@ class Engine:
return output_dict
def resume(self):
r"""
Gets the initial value of gradio components and restores training status if necessary.
"""
r"""Get the initial value of gradio components and restores training status if necessary."""
user_config = load_config() if not self.demo_mode else {} # do not use config in demo mode
lang = user_config.get("lang", None) or "en"
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
......@@ -79,9 +73,7 @@ class Engine:
yield self._update_component({"eval.resume_btn": {"value": True}})
def change_lang(self, lang: str):
r"""
Updates the displayed language of gradio components.
"""
r"""Update the displayed language of gradio components."""
return {
elem: elem.__class__(**LOCALES[elem_name][lang])
for elem_name, elem in self.manager.get_elem_iter()
......
......@@ -15,7 +15,7 @@
import os
import platform
from ..extras.misc import is_env_enabled
from ..extras.misc import fix_proxy, is_env_enabled
from ..extras.packages import is_gradio_available
from .common import save_config
from .components import (
......@@ -48,7 +48,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
engine.manager.add_elems("top", create_top())
lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")
with gr.Tab("Train"):
engine.manager.add_elems("train", create_train_tab(engine))
......@@ -72,8 +72,9 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
def create_web_demo() -> "gr.Blocks":
engine = Engine(pure_chat=True)
hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]
with gr.Blocks(title="Web Demo", css=CSS) as demo:
with gr.Blocks(title=f"LLaMA Factory Web Demo ({hostname})", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], scale=1)
engine.manager.add_elems("top", dict(lang=lang))
......@@ -91,6 +92,8 @@ def run_web_ui() -> None:
gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
gradio_share = is_env_enabled("GRADIO_SHARE")
server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
fix_proxy(ipv6_enabled=gradio_ipv6)
create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
......@@ -98,4 +101,6 @@ def run_web_demo() -> None:
gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
gradio_share = is_env_enabled("GRADIO_SHARE")
server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
fix_proxy(ipv6_enabled=gradio_ipv6)
create_web_demo().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
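Both launchers now print the local URL hint and call fix_proxy(ipv6_enabled=...) before launch(). The helper comes from llamafactory.extras.misc and its body is not part of this diff; a plausible reading, sketched below purely as an assumption, is that it exempts the local bind addresses from any configured HTTP proxy so the browser can reach the Gradio server:

import os

# Assumption: minimal sketch of what a proxy fix could do; the real
# fix_proxy() in llamafactory.extras.misc may differ in behavior and scope.
def fix_proxy(ipv6_enabled: bool = False) -> None:
    hosts = {"localhost", "127.0.0.1", "0.0.0.0"}
    if ipv6_enabled:
        hosts.update({"::1", "[::]"})
    existing = [h for h in os.environ.get("no_proxy", "").split(",") if h]
    os.environ["no_proxy"] = ",".join(sorted(hosts.union(existing)))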
......@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
from collections.abc import Generator
from typing import TYPE_CHECKING
if TYPE_CHECKING:
......@@ -20,54 +21,41 @@ if TYPE_CHECKING:
class Manager:
r"""
A class to manage all the gradio components in Web UI.
"""
r"""A class to manage all the gradio components in Web UI."""
def __init__(self) -> None:
self._id_to_elem: Dict[str, "Component"] = {}
self._elem_to_id: Dict["Component", str] = {}
self._id_to_elem: dict[str, Component] = {}
self._elem_to_id: dict[Component, str] = {}
def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
r"""
Adds elements to manager.
"""
def add_elems(self, tab_name: str, elem_dict: dict[str, "Component"]) -> None:
r"""Add elements to manager."""
for elem_name, elem in elem_dict.items():
elem_id = f"{tab_name}.{elem_name}"
self._id_to_elem[elem_id] = elem
self._elem_to_id[elem] = elem_id
def get_elem_list(self) -> List["Component"]:
r"""
Returns the list of all elements.
"""
def get_elem_list(self) -> list["Component"]:
r"""Return the list of all elements."""
return list(self._id_to_elem.values())
def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
r"""
Returns an iterator over all elements with their names.
"""
def get_elem_iter(self) -> Generator[tuple[str, "Component"], None, None]:
r"""Return an iterator over all elements with their names."""
for elem_id, elem in self._id_to_elem.items():
yield elem_id.split(".")[-1], elem
def get_elem_by_id(self, elem_id: str) -> "Component":
r"""
Gets element by id.
r"""Get element by id.
Example: top.lang, train.dataset
"""
return self._id_to_elem[elem_id]
def get_id_by_elem(self, elem: "Component") -> str:
r"""
Gets id by element.
"""
r"""Get id by element."""
return self._elem_to_id[elem]
def get_base_elems(self) -> Set["Component"]:
r"""
Gets the base elements that are commonly used.
"""
def get_base_elems(self) -> set["Component"]:
r"""Get the base elements that are commonly used."""
return {
self._id_to_elem["top.lang"],
self._id_to_elem["top.model_name"],
......
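For context, Manager keeps two mirrored dictionaries, so every component is addressable either by its "tab.element" id or by the component object itself. A short usage sketch based only on the methods shown above (assuming the class is importable as llamafactory.webui.manager.Manager; the component values are placeholders):

import gradio as gr

from llamafactory.webui.manager import Manager

manager = Manager()
lang = gr.Dropdown(choices=["en", "zh"])
model_name = gr.Dropdown(choices=["Custom"])
manager.add_elems("top", {"lang": lang, "model_name": model_name})  # ids: top.lang, top.model_name

assert manager.get_elem_by_id("top.lang") is lang
assert manager.get_id_by_elem(model_name) == "top.model_name"
# get_elem_iter() yields (short_name, component) pairs, e.g. ("lang", lang).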
......@@ -14,9 +14,10 @@
import json
import os
from collections.abc import Generator
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
from typing import TYPE_CHECKING, Any, Optional
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_npu_available
......@@ -51,17 +52,16 @@ if TYPE_CHECKING:
class Runner:
r"""
A class to manage the running status of the trainers.
"""
r"""A class to manage the running status of the trainers."""
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
r"""Init a runner."""
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
self.trainer: Optional["Popen"] = None
self.trainer: Optional[Popen] = None
self.do_train = True
self.running_data: Dict["Component", Any] = None
self.running_data: dict[Component, Any] = None
""" State """
self.aborted = False
self.running = False
......@@ -71,10 +71,8 @@ class Runner:
if self.trainer is not None:
abort_process(self.trainer.pid)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
r"""
Validates the configuration.
"""
def _initialize(self, data: dict["Component", Any], do_train: bool, from_preview: bool) -> str:
r"""Validate the configuration."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
......@@ -116,9 +114,7 @@ class Runner:
return ""
def _finalize(self, lang: str, finish_info: str) -> str:
r"""
Cleans the cached memory and resets the runner.
"""
r"""Clean the cached memory and resets the runner."""
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
gr.Info(finish_info)
self.trainer = None
......@@ -128,10 +124,8 @@ class Runner:
torch_gc()
return finish_info
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds and validates the training arguments.
"""
def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
r"""Build and validate the training arguments."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
......@@ -291,10 +285,8 @@ class Runner:
return args
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds and validates the evaluation arguments.
"""
def _parse_eval_args(self, data: dict["Component", Any]) -> dict[str, Any]:
r"""Build and validate the evaluation arguments."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
......@@ -345,10 +337,8 @@ class Runner:
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
r"""
Previews the training commands.
"""
def _preview(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", str], None, None]:
r"""Preview the training commands."""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
......@@ -358,10 +348,8 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
r"""
Starts the training process.
"""
def _launch(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", Any], None, None]:
r"""Start the training process."""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
......@@ -383,10 +371,8 @@ class Runner:
self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
yield from self.monitor()
def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds a dictionary containing the current training configuration.
"""
def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
r"""Build a dictionary containing the current training configuration."""
config_dict = {}
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items():
......@@ -409,9 +395,7 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self):
r"""
Monitors the training progress and logs.
"""
r"""Monitorgit the training progress and logs."""
self.aborted = False
self.running = True
......@@ -469,9 +453,7 @@ class Runner:
yield return_dict
def save_args(self, data):
r"""
Saves the training configuration to config path.
"""
r"""Save the training configuration to config path."""
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
......@@ -487,27 +469,23 @@ class Runner:
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str):
r"""
Loads the training configuration from config path.
"""
r"""Load the training configuration from config path."""
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
if config_dict is None:
gr.Warning(ALERTS["err_config_not_found"][lang])
return {output_box: ALERTS["err_config_not_found"][lang]}
output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
output_dict: dict[Component, Any] = {output_box: ALERTS["info_config_loaded"][lang]}
for elem_id, value in config_dict.items():
output_dict[self.manager.get_elem_by_id(elem_id)] = value
return output_dict
def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
r"""
Restore the training status if output_dir exists.
"""
r"""Restore the training status if output_dir exists."""
output_box = self.manager.get_elem_by_id("train.output_box")
output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
output_dict: dict[Component, Any] = {output_box: LOCALES["output_box"][lang]["value"]}
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
gr.Warning(ALERTS["warn_output_dir_exists"][lang])
output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]
......
......@@ -14,7 +14,7 @@
import os
from llamafactory.extras.misc import is_env_enabled
from llamafactory.extras.misc import fix_proxy, is_env_enabled
from llamafactory.webui.interface import create_ui
......@@ -22,6 +22,8 @@ def main():
gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
gradio_share = is_env_enabled("GRADIO_SHARE")
server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
fix_proxy(ipv6_enabled=gradio_ipv6)
create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from pathlib import Path
KEYWORDS = ("Copyright", "2025", "LlamaFactory")
def main():
path_list: list[Path] = []
for check_dir in sys.argv[1:]:
path_list.extend(Path(check_dir).glob("**/*.py"))
for path in path_list:
with open(path.absolute(), encoding="utf-8") as f:
file_content = f.read().strip().split("\n")
if not file_content[0]:
continue
print(f"Check license: {path}")
assert all(keyword in file_content[0] for keyword in KEYWORDS), f"File {path} does not contain license."
if __name__ == "__main__":
main()
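The new check walks every directory passed on the command line, globs for *.py files, and asserts that the first line of each non-empty file contains all of the KEYWORDS. A hedged invocation example; the script path below is illustrative, since the diff does not show where the file lives in the repository:

import subprocess
import sys

# Run the license checker over two source trees (paths are examples only).
subprocess.run(
    [sys.executable, "tests/check_license.py", "src", "tests"],
    check=True,  # an uncaught AssertionError makes the script exit non-zero
)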
......@@ -14,7 +14,6 @@
import os
import random
from typing import Dict, List
import pytest
from datasets import load_dataset
......@@ -43,7 +42,7 @@ TRAIN_ARGS = {
}
def _convert_sharegpt_to_openai(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str, str]]:
role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}
new_messages = []
for message in messages:
......
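The helper maps ShareGPT roles (human, gpt, system) onto OpenAI-style roles before the messages reach the collator test. The loop body is truncated in this hunk; the example below assumes the conventional ShareGPT keys ("from"/"value") and is only an illustration of the intended conversion:

# Assumed completion of the conversion, for illustration only.
def convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str, str]]:
    role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}
    return [{"role": role_mapping[m["from"]], "content": m["value"]} for m in messages]

sharegpt = [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]
assert convert_sharegpt_to_openai(sharegpt) == [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]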
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import pytest
......@@ -31,5 +30,5 @@ from llamafactory.data.processor.processor_utils import infer_seqlen
((10, 10, 1000), (10, 10)),
],
)
def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
def test_infer_seqlen(test_input: tuple[int, int, int], test_output: tuple[int, int]):
assert test_output == infer_seqlen(*test_input)
......@@ -112,7 +112,8 @@ def test_glm4_tool_formatter():
assert formatter.apply(content=json.dumps(TOOLS)) == [
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n"
"在调用上述函数时,请使用 Json 格式表示调用的参数。"
]
......@@ -136,7 +137,8 @@ def test_llama3_tool_formatter():
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"You have access to the following functions. "
"To call a function, please respond with JSON for a function call. "
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
]
......
......@@ -13,13 +13,14 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
from typing import TYPE_CHECKING, Any
import pytest
import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
......@@ -69,12 +70,12 @@ LABELS = [0, 1, 2, 3, 4]
BATCH_IDS = [[1] * 1024]
def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
return image_processor(images=IMAGES, return_tensors="pt")
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
assert batch_a.keys() == batch_b.keys()
for key in batch_a.keys():
if isinstance(batch_a[key], torch.Tensor):
......@@ -96,11 +97,11 @@ def _check_plugin(
plugin: "BasePlugin",
tokenizer: "PreTrainedTokenizer",
processor: "ProcessorMixin",
expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
expected_input_ids: List[int] = INPUT_IDS,
expected_labels: List[int] = LABELS,
expected_mm_inputs: Dict[str, Any] = {},
expected_no_mm_inputs: Dict[str, Any] = {},
expected_mm_messages: list[dict[str, str]] = MM_MESSAGES,
expected_input_ids: list[int] = INPUT_IDS,
expected_labels: list[int] = LABELS,
expected_mm_inputs: dict[str, Any] = {},
expected_no_mm_inputs: dict[str, Any] = {},
) -> None:
# test mm_messages
if plugin.__class__.__name__ != "BasePlugin":
......@@ -135,6 +136,27 @@ def test_base_plugin():
_check_plugin(**check_inputs)
@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
def test_gemma3_plugin():
image_seqlen = 256
tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
gemma3_plugin = get_mm_plugin(name="gemma3", image_token="<image_soft_token>")
image_tokens_expanded = "<image_soft_token>" * image_seqlen
check_inputs = {"plugin": gemma3_plugin, **tokenizer_module}
check_inputs["expected_mm_messages"] = [
{
key: value.replace("<image>", f"\n\n<start_of_image>{image_tokens_expanded}<end_of_image>\n\n")
for key, value in message.items()
}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
check_inputs["expected_mm_inputs"].pop("num_crops")
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * 1024]
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[0] * 1024]}
_check_plugin(**check_inputs)
def test_llava_plugin():
image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
......@@ -210,7 +232,6 @@ def test_pixtral_plugin():
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
check_inputs["expected_mm_inputs"].pop("image_sizes")
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
_check_plugin(**check_inputs)
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING
import pytest
from transformers import AutoTokenizer
......@@ -40,10 +40,9 @@ MESSAGES = [
def _check_tokenization(
tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
) -> None:
r"""
Checks token ids and texts.
r"""Check token ids and texts.
encode(text) == token_ids
decode(token_ids) == text
......@@ -54,8 +53,7 @@ def _check_tokenization(
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
r"""
Checks template.
r"""Check template.
Args:
model_id: the model id on hugging face hub.
......@@ -63,6 +61,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
prompt_str: the string corresponding to the prompt part.
answer_str: the string corresponding to the answer part.
use_fast: whether to use fast tokenizer.
"""
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import pytest
from llamafactory.chat import ChatModel
from llamafactory.extras.packages import is_sglang_available
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
INFER_ARGS = {
"model_name_or_path": MODEL_NAME,
"finetuning_type": "lora",
"template": "llama3",
"infer_dtype": "float16",
"infer_backend": "sglang",
"do_sample": False,
"max_new_tokens": 1,
}
MESSAGES = [
{"role": "user", "content": "Hi"},
]
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_chat():
r"""Test the SGLang engine's basic chat functionality."""
chat_model = ChatModel(INFER_ARGS)
response = chat_model.chat(MESSAGES)[0]
# TODO: Change to EXPECTED_RESPONSE
print(response.response_text)
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_stream_chat():
r"""Test the SGLang engine's streaming chat functionality."""
chat_model = ChatModel(INFER_ARGS)
response = ""
for token in chat_model.stream_chat(MESSAGES):
response += token
print("Complete response:", response)
assert response, "Should receive a non-empty response"
# Run tests if executed directly
if __name__ == "__main__":
if not is_sglang_available():
print("SGLang is not available. Please install it.")
sys.exit(1)
test_chat()
test_stream_chat()
......@@ -62,5 +62,5 @@ def test_upcast_layernorm():
def test_upcast_lmhead_output():
model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
outputs: torch.Tensor = model.get_output_embeddings()(inputs)
assert outputs.dtype == torch.float32
......@@ -48,8 +48,6 @@ INFER_ARGS = {
"infer_dtype": "float16",
}
OS_NAME = os.getenv("OS_NAME", "")
@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")
def test_pissa_train():
......@@ -58,7 +56,7 @@ def test_pissa_train():
compare_model(model, ref_model)
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
@pytest.mark.xfail(reason="Known connection error.")
def test_pissa_inference():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)
......
......@@ -14,7 +14,7 @@
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List
from typing import Any
import pytest
from transformers import DataCollatorWithPadding
......@@ -46,9 +46,9 @@ TRAIN_ARGS = {
@dataclass
class DataCollatorWithVerbose(DataCollatorWithPadding):
verbose_list: List[Dict[str, Any]] = field(default_factory=list)
verbose_list: list[dict[str, Any]] = field(default_factory=list)
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
self.verbose_list.extend(features)
batch = super().__call__(features)
return {k: v[:, :1] for k, v in batch.items()} # truncate input length
......
# change if test fails
0.9.3.101