Commit 7ea81099 authored by chenych

update llama4

parent 84987715
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from gradio.components import Component
-def create_top() -> Dict[str, "Component"]:
+def create_top() -> dict[str, "Component"]:
with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
...
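The recurring change across this commit is a typing cleanup: PEP 585 built-in generics (dict, list, tuple, set) and collections.abc.Generator replace their typing.* aliases, which are deprecated since Python 3.9. A minimal illustration of the before/after annotation style, not taken from the diff itself:

from typing import Dict, List  # pre-3.9 style, now a deprecated alias


def old_style(elems: Dict[str, int]) -> List[str]:
    return list(elems)


def new_style(elems: dict[str, int]) -> list[str]:  # built-in generics, Python >= 3.9
    return list(elems)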
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
from transformers.trainer_utils import SchedulerType
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
from ..engine import Engine
-def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
+def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
@@ -382,8 +382,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
lang = engine.manager.get_elem_by_id("top.lang")
-model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name")
+model_name: gr.Dropdown = engine.manager.get_elem_by_id("top.model_name")
-finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type")
+finetuning_type: gr.Dropdown = engine.manager.get_elem_by_id("top.finetuning_type")
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
...
@@ -14,7 +14,7 @@
import json
import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional
from transformers.trainer_utils import get_last_checkpoint
@@ -39,8 +39,7 @@ if is_gradio_available():
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
-r"""
-Judges if the quantization is available in this finetuning type.
+r"""Judge if the quantization is available in this finetuning type.
Inputs: top.finetuning_type
Outputs: top.quantization_bit
@@ -52,8 +51,7 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
-r"""
-Gets the available quantization bits.
+r"""Get the available quantization bits.
Inputs: top.quantization_method
Outputs: top.quantization_bit
@@ -68,9 +66,8 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
return gr.Dropdown(choices=available_bits)
-def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
+def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> tuple[list[str], bool]:
-r"""
-Modifys states after changing the training stage.
+r"""Modify states after changing the training stage.
Inputs: train.training_stage
Outputs: train.dataset, train.packing
@@ -78,9 +75,8 @@ def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple
return [], TRAINING_STAGES[training_stage] == "pt"
-def get_model_info(model_name: str) -> Tuple[str, str]:
+def get_model_info(model_name: str) -> tuple[str, str]:
-r"""
-Gets the necessary information of this model.
+r"""Get the necessary information of this model.
Inputs: top.model_name
Outputs: top.model_path, top.template
@@ -88,9 +84,8 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
return get_model_path(model_name), get_template(model_name)
-def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
+def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
-r"""
-Gets training infomation for monitor.
+r"""Get training information for the monitor.
If do_train is True:
Inputs: top.lang, train.output_path
@@ -110,7 +105,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
-trainer_log: List[Dict[str, Any]] = []
+trainer_log: list[dict[str, Any]] = []
with open(trainer_log_path, encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
@@ -143,8 +138,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
-r"""
-Lists all available checkpoints.
+r"""List all available checkpoints.
Inputs: top.model_name, top.finetuning_type
Outputs: top.checkpoint_path
@@ -166,8 +160,7 @@ def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
def list_config_paths(current_time: str) -> "gr.Dropdown":
-r"""
-Lists all the saved configuration files.
+r"""List all the saved configuration files.
Inputs: train.current_time
Outputs: train.config_path
@@ -182,8 +175,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
-r"""
-Lists all available datasets in the dataset dir for the training stage.
+r"""List all available datasets in the dataset dir for the training stage.
Inputs: *.dataset_dir, *.training_stage
Outputs: *.dataset
@@ -195,8 +187,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
-r"""
-Lists all the directories that can resume from.
+r"""List all the directories that can resume from.
Inputs: top.model_name, top.finetuning_type, train.current_time
Outputs: train.output_dir
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any
from .chatter import WebChatModel
from .common import create_ds_config, get_time, load_config
@@ -26,9 +26,7 @@ if TYPE_CHECKING:
class Engine:
-r"""
-A general engine to control the behaviors of Web UI.
-"""
+r"""A general engine to control the behaviors of Web UI."""
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode
@@ -39,11 +37,9 @@ class Engine:
if not demo_mode:
create_ds_config()
-def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
+def _update_component(self, input_dict: dict[str, dict[str, Any]]) -> dict["Component", "Component"]:
-r"""
-Updates gradio components according to the (elem_id, properties) mapping.
-"""
-output_dict: Dict["Component", "Component"] = {}
+r"""Update gradio components according to the (elem_id, properties) mapping."""
+output_dict: dict[Component, Component] = {}
for elem_id, elem_attr in input_dict.items():
elem = self.manager.get_elem_by_id(elem_id)
output_dict[elem] = elem.__class__(**elem_attr)
@@ -51,9 +47,7 @@ class Engine:
return output_dict
def resume(self):
-r"""
-Gets the initial value of gradio components and restores training status if necessary.
-"""
+r"""Get the initial value of gradio components and restores training status if necessary."""
user_config = load_config() if not self.demo_mode else {} # do not use config in demo mode
lang = user_config.get("lang", None) or "en"
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
@@ -79,9 +73,7 @@ class Engine:
yield self._update_component({"eval.resume_btn": {"value": True}})
def change_lang(self, lang: str):
-r"""
-Updates the displayed language of gradio components.
-"""
+r"""Update the displayed language of gradio components."""
return {
elem: elem.__class__(**LOCALES[elem_name][lang])
for elem_name, elem in self.manager.get_elem_iter()
...
@@ -15,7 +15,7 @@
import os
import platform
-from ..extras.misc import is_env_enabled
+from ..extras.misc import fix_proxy, is_env_enabled
from ..extras.packages import is_gradio_available
from .common import save_config
from .components import (
@@ -48,7 +48,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
engine.manager.add_elems("top", create_top())
-lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
+lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")
with gr.Tab("Train"):
engine.manager.add_elems("train", create_train_tab(engine))
@@ -72,8 +72,9 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
def create_web_demo() -> "gr.Blocks":
engine = Engine(pure_chat=True)
+hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]
-with gr.Blocks(title="Web Demo", css=CSS) as demo:
+with gr.Blocks(title=f"LLaMA Factory Web Demo ({hostname})", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], scale=1)
engine.manager.add_elems("top", dict(lang=lang))
@@ -91,6 +92,8 @@ def run_web_ui() -> None:
gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
gradio_share = is_env_enabled("GRADIO_SHARE")
server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+fix_proxy(ipv6_enabled=gradio_ipv6)
create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
@@ -98,4 +101,6 @@ def run_web_demo() -> None:
gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
gradio_share = is_env_enabled("GRADIO_SHARE")
server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+fix_proxy(ipv6_enabled=gradio_ipv6)
create_web_demo().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
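The new fix_proxy helper is imported from ..extras.misc, but its body is not part of this diff. A minimal sketch of what such a helper plausibly does, assuming it only exempts local Gradio traffic from any configured HTTP(S) proxy (names and behavior below are guesses, not the project's actual implementation):

import os


def fix_proxy(ipv6_enabled: bool = False) -> None:
    """Make sure requests to the local Gradio server bypass any configured proxy."""
    hosts = {"localhost", "127.0.0.1"}
    if ipv6_enabled:
        hosts.add("::1")  # also exempt the IPv6 loopback

    existing = {h for h in os.getenv("no_proxy", "").split(",") if h}
    merged = ",".join(sorted(hosts | existing))
    os.environ["no_proxy"] = merged
    os.environ["NO_PROXY"] = merged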
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
+from collections.abc import Generator
+from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -20,54 +21,41 @@ if TYPE_CHECKING:
class Manager:
-r"""
-A class to manage all the gradio components in Web UI.
-"""
+r"""A class to manage all the gradio components in Web UI."""
def __init__(self) -> None:
-self._id_to_elem: Dict[str, "Component"] = {}
+self._id_to_elem: dict[str, Component] = {}
-self._elem_to_id: Dict["Component", str] = {}
+self._elem_to_id: dict[Component, str] = {}
-def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
+def add_elems(self, tab_name: str, elem_dict: dict[str, "Component"]) -> None:
-r"""
-Adds elements to manager.
-"""
+r"""Add elements to manager."""
for elem_name, elem in elem_dict.items():
elem_id = f"{tab_name}.{elem_name}"
self._id_to_elem[elem_id] = elem
self._elem_to_id[elem] = elem_id
-def get_elem_list(self) -> List["Component"]:
+def get_elem_list(self) -> list["Component"]:
-r"""
-Returns the list of all elements.
-"""
+r"""Return the list of all elements."""
return list(self._id_to_elem.values())
-def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
+def get_elem_iter(self) -> Generator[tuple[str, "Component"], None, None]:
-r"""
-Returns an iterator over all elements with their names.
-"""
+r"""Return an iterator over all elements with their names."""
for elem_id, elem in self._id_to_elem.items():
yield elem_id.split(".")[-1], elem
def get_elem_by_id(self, elem_id: str) -> "Component":
-r"""
-Gets element by id.
+r"""Get element by id.
Example: top.lang, train.dataset
"""
return self._id_to_elem[elem_id]
def get_id_by_elem(self, elem: "Component") -> str:
-r"""
-Gets id by element.
-"""
+r"""Get id by element."""
return self._elem_to_id[elem]
-def get_base_elems(self) -> Set["Component"]:
+def get_base_elems(self) -> set["Component"]:
-r"""
-Gets the base elements that are commonly used.
-"""
+r"""Get the base elements that are commonly used."""
return {
self._id_to_elem["top.lang"],
self._id_to_elem["top.model_name"],
...
@@ -14,9 +14,10 @@
import json
import os
+from collections.abc import Generator
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
-from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
+from typing import TYPE_CHECKING, Any, Optional
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_npu_available
@@ -51,17 +52,16 @@ if TYPE_CHECKING:
class Runner:
-r"""
-A class to manage the running status of the trainers.
-"""
+r"""A class to manage the running status of the trainers."""
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
+r"""Init a runner."""
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
-self.trainer: Optional["Popen"] = None
+self.trainer: Optional[Popen] = None
self.do_train = True
-self.running_data: Dict["Component", Any] = None
+self.running_data: dict[Component, Any] = None
""" State """
self.aborted = False
self.running = False
@@ -71,10 +71,8 @@ class Runner:
if self.trainer is not None:
abort_process(self.trainer.pid)
-def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
+def _initialize(self, data: dict["Component", Any], do_train: bool, from_preview: bool) -> str:
-r"""
-Validates the configuration.
-"""
+r"""Validate the configuration."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
@@ -116,9 +114,7 @@ class Runner:
return ""
def _finalize(self, lang: str, finish_info: str) -> str:
-r"""
-Cleans the cached memory and resets the runner.
-"""
+r"""Clean the cached memory and resets the runner."""
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
gr.Info(finish_info)
self.trainer = None
@@ -128,10 +124,8 @@ class Runner:
torch_gc()
return finish_info
-def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
-r"""
-Builds and validates the training arguments.
-"""
+r"""Build and validate the training arguments."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@@ -291,10 +285,8 @@ class Runner:
return args
-def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+def _parse_eval_args(self, data: dict["Component", Any]) -> dict[str, Any]:
-r"""
-Builds and validates the evaluation arguments.
-"""
+r"""Build and validate the evaluation arguments."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@@ -345,10 +337,8 @@ class Runner:
return args
-def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
+def _preview(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", str], None, None]:
-r"""
-Previews the training commands.
-"""
+r"""Preview the training commands."""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
@@ -358,10 +348,8 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield {output_box: gen_cmd(args)}
-def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
+def _launch(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", Any], None, None]:
-r"""
-Starts the training process.
-"""
+r"""Start the training process."""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
@@ -383,10 +371,8 @@ class Runner:
self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
yield from self.monitor()
-def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
-r"""
-Builds a dictionary containing the current training configuration.
-"""
+r"""Build a dictionary containing the current training configuration."""
config_dict = {}
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items():
@@ -409,9 +395,7 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self):
r""" r"""Monitorgit the training progress and logs."""
Monitors the training progress and logs.
"""
self.aborted = False
self.running = True
@@ -469,9 +453,7 @@ class Runner:
yield return_dict
def save_args(self, data):
-r"""
-Saves the training configuration to config path.
-"""
+r"""Save the training configuration to config path."""
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
@@ -487,27 +469,23 @@ class Runner:
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str):
-r"""
-Loads the training configuration from config path.
-"""
+r"""Load the training configuration from config path."""
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
if config_dict is None:
gr.Warning(ALERTS["err_config_not_found"][lang])
return {output_box: ALERTS["err_config_not_found"][lang]}
-output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
+output_dict: dict[Component, Any] = {output_box: ALERTS["info_config_loaded"][lang]}
for elem_id, value in config_dict.items():
output_dict[self.manager.get_elem_by_id(elem_id)] = value
return output_dict
def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
-r"""
-Restore the training status if output_dir exists.
-"""
+r"""Restore the training status if output_dir exists."""
output_box = self.manager.get_elem_by_id("train.output_box")
-output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
+output_dict: dict[Component, Any] = {output_box: LOCALES["output_box"][lang]["value"]}
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
gr.Warning(ALERTS["warn_output_dir_exists"][lang])
output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]
...
@@ -14,7 +14,7 @@
import os
-from llamafactory.extras.misc import is_env_enabled
+from llamafactory.extras.misc import fix_proxy, is_env_enabled
from llamafactory.webui.interface import create_ui
@@ -22,6 +22,8 @@ def main():
gradio_ipv6 = is_env_enabled("GRADIO_IPV6")
gradio_share = is_env_enabled("GRADIO_SHARE")
server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
+print("Visit http://ip:port for Web UI, e.g., http://127.0.0.1:7860")
+fix_proxy(ipv6_enabled=gradio_ipv6)
create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
...
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from pathlib import Path
KEYWORDS = ("Copyright", "2025", "LlamaFactory")
def main():
path_list: list[Path] = []
for check_dir in sys.argv[1:]:
path_list.extend(Path(check_dir).glob("**/*.py"))
for path in path_list:
with open(path.absolute(), encoding="utf-8") as f:
file_content = f.read().strip().split("\n")
if not file_content[0]:
continue
print(f"Check license: {path}")
assert all(keyword in file_content[0] for keyword in KEYWORDS), f"File {path} does not contain license."
if __name__ == "__main__":
main()
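The script above skips empty files and requires every keyword to appear on the first line of each remaining file. A quick illustration of that predicate; the sample strings below are made up for this note, not taken from the repository:

KEYWORDS = ("Copyright", "2025", "LlamaFactory")

good_first_line = "# Copyright 2025 the LlamaFactory team."
bad_first_line = "import os"

assert all(keyword in good_first_line for keyword in KEYWORDS)  # header accepted
assert not all(keyword in bad_first_line for keyword in KEYWORDS)  # would trip the assertion in main()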
@@ -14,7 +14,6 @@
import os
import random
-from typing import Dict, List
import pytest
from datasets import load_dataset
@@ -43,7 +42,7 @@ TRAIN_ARGS = {
}
-def _convert_sharegpt_to_openai(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+def _convert_sharegpt_to_openai(messages: list[dict[str, str]]) -> list[dict[str, str]]:
role_mapping = {"human": "user", "gpt": "assistant", "system": "system"}
new_messages = []
for message in messages:
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Tuple
import pytest
@@ -31,5 +30,5 @@ from llamafactory.data.processor.processor_utils import infer_seqlen
((10, 10, 1000), (10, 10)),
],
)
-def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):
+def test_infer_seqlen(test_input: tuple[int, int, int], test_output: tuple[int, int]):
assert test_output == infer_seqlen(*test_input)
@@ -112,7 +112,8 @@ def test_glm4_tool_formatter():
assert formatter.apply(content=json.dumps(TOOLS)) == [
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
-f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
+f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n"
+"在调用上述函数时,请使用 Json 格式表示调用的参数。"
]
@@ -136,7 +137,8 @@ def test_llama3_tool_formatter():
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
-"You have access to the following functions. To call a function, please respond with JSON for a function call. "
+"You have access to the following functions. "
+"To call a function, please respond with JSON for a function call. "
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
]
...
@@ -13,13 +13,14 @@
# limitations under the License.
import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence
+from typing import TYPE_CHECKING, Any
import pytest
import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
+from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -69,12 +70,12 @@ LABELS = [0, 1, 2, 3, 4]
BATCH_IDS = [[1] * 1024]
-def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
+def _get_mm_inputs(processor: "ProcessorMixin") -> dict[str, "torch.Tensor"]:
-image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+image_processor: BaseImageProcessor = getattr(processor, "image_processor")
return image_processor(images=IMAGES, return_tensors="pt")
-def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
+def _is_close(batch_a: dict[str, Any], batch_b: dict[str, Any]) -> None:
assert batch_a.keys() == batch_b.keys()
for key in batch_a.keys():
if isinstance(batch_a[key], torch.Tensor):
@@ -96,11 +97,11 @@ def _check_plugin(
plugin: "BasePlugin",
tokenizer: "PreTrainedTokenizer",
processor: "ProcessorMixin",
-expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
+expected_mm_messages: list[dict[str, str]] = MM_MESSAGES,
-expected_input_ids: List[int] = INPUT_IDS,
+expected_input_ids: list[int] = INPUT_IDS,
-expected_labels: List[int] = LABELS,
+expected_labels: list[int] = LABELS,
-expected_mm_inputs: Dict[str, Any] = {},
+expected_mm_inputs: dict[str, Any] = {},
-expected_no_mm_inputs: Dict[str, Any] = {},
+expected_no_mm_inputs: dict[str, Any] = {},
) -> None:
# test mm_messages
if plugin.__class__.__name__ != "BasePlugin":
@@ -135,6 +136,27 @@ def test_base_plugin():
_check_plugin(**check_inputs)
+@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
+def test_gemma3_plugin():
+image_seqlen = 256
+tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
+gemma3_plugin = get_mm_plugin(name="gemma3", image_token="<image_soft_token>")
+image_tokens_expanded = "<image_soft_token>" * image_seqlen
+check_inputs = {"plugin": gemma3_plugin, **tokenizer_module}
+check_inputs["expected_mm_messages"] = [
+{
+key: value.replace("<image>", f"\n\n<start_of_image>{image_tokens_expanded}<end_of_image>\n\n")
+for key, value in message.items()
+}
+for message in MM_MESSAGES
+]
+check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
+check_inputs["expected_mm_inputs"].pop("num_crops")
+check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * 1024]
+check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[0] * 1024]}
+_check_plugin(**check_inputs)
def test_llava_plugin():
image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
@@ -210,7 +232,6 @@ def test_pixtral_plugin():
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
-check_inputs["expected_mm_inputs"].pop("image_sizes")
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
_check_plugin(**check_inputs)
...
@@ -13,7 +13,7 @@
# limitations under the License.
import os
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING
import pytest
from transformers import AutoTokenizer
@@ -40,10 +40,9 @@ MESSAGES = [
def _check_tokenization(
-tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
+tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
) -> None:
-r"""
-Checks token ids and texts.
+r"""Check token ids and texts.
encode(text) == token_ids
decode(token_ids) == text
@@ -54,8 +53,7 @@ def _check_tokenization(
def _check_template(model_id: str, template_name: str, prompt_str: str, answer_str: str, use_fast: bool) -> None:
-r"""
-Checks template.
+r"""Check template.
Args:
model_id: the model id on hugging face hub.
@@ -63,6 +61,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
prompt_str: the string corresponding to the prompt part.
answer_str: the string corresponding to the answer part.
use_fast: whether to use fast tokenizer.
"""
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
...
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import pytest
from llamafactory.chat import ChatModel
from llamafactory.extras.packages import is_sglang_available
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
INFER_ARGS = {
"model_name_or_path": MODEL_NAME,
"finetuning_type": "lora",
"template": "llama3",
"infer_dtype": "float16",
"infer_backend": "sglang",
"do_sample": False,
"max_new_tokens": 1,
}
MESSAGES = [
{"role": "user", "content": "Hi"},
]
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_chat():
r"""Test the SGLang engine's basic chat functionality."""
chat_model = ChatModel(INFER_ARGS)
response = chat_model.chat(MESSAGES)[0]
# TODO: Change to EXPECTED_RESPONSE
print(response.response_text)
@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
def test_stream_chat():
r"""Test the SGLang engine's streaming chat functionality."""
chat_model = ChatModel(INFER_ARGS)
response = ""
for token in chat_model.stream_chat(MESSAGES):
response += token
print("Complete response:", response)
assert response, "Should receive a non-empty response"
# Run tests if executed directly
if __name__ == "__main__":
if not is_sglang_available():
print("SGLang is not available. Please install it.")
sys.exit(1)
test_chat()
test_stream_chat()
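To exercise only these SGLang tests, one could point pytest at the new file via its Python API; the file path below is an assumption for illustration, since the diff does not show where the file lives in the repository:

import pytest

# Hypothetical invocation of the new SGLang tests; adjust the path to the actual location.
pytest.main(["-q", "tests/e2e/test_sglang.py"])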
@@ -62,5 +62,5 @@ def test_upcast_layernorm():
def test_upcast_lmhead_output():
model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
-outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
+outputs: torch.Tensor = model.get_output_embeddings()(inputs)
assert outputs.dtype == torch.float32
@@ -48,8 +48,6 @@ INFER_ARGS = {
"infer_dtype": "float16",
}
-OS_NAME = os.getenv("OS_NAME", "")
@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")
def test_pissa_train():
@@ -58,7 +56,7 @@ def test_pissa_train():
compare_model(model, ref_model)
-@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
+@pytest.mark.xfail(reason="Known connection error.")
def test_pissa_inference():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)
...
@@ -14,7 +14,7 @@
import os
from dataclasses import dataclass, field
-from typing import Any, Dict, List
+from typing import Any
import pytest
from transformers import DataCollatorWithPadding
@@ -46,9 +46,9 @@ TRAIN_ARGS = {
@dataclass
class DataCollatorWithVerbose(DataCollatorWithPadding):
-verbose_list: List[Dict[str, Any]] = field(default_factory=list)
+verbose_list: list[dict[str, Any]] = field(default_factory=list)
-def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
self.verbose_list.extend(features)
batch = super().__call__(features)
return {k: v[:, :1] for k, v in batch.items()} # truncate input length
...
# change if test fails
0.9.3.101