Commit d1588ee7 authored by chenych

update 0718

parent 358bd2a0
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Literal, Union
 import torch
 import torch.distributed as dist
 import transformers.dynamic_module_utils
+from huggingface_hub.utils import WeakFileLock
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.dynamic_module_utils import get_relative_imports
 from transformers.utils import (
@@ -35,7 +36,6 @@ from transformers.utils import (
 from transformers.utils.versions import require_version

 from . import logging
-from .packages import is_transformers_version_greater_than

 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -94,15 +94,11 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version(
-        "transformers>=4.45.0,<=4.52.4,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
-    )
+    check_version("transformers>=4.49.0,<=4.52.4,!=4.52.0")
     check_version("datasets>=2.16.0,<=3.6.0")
-    check_version("accelerate>=0.34.0,<=1.7.0")
+    check_version("accelerate>=1.3.0,<=1.7.0")
     check_version("peft>=0.14.0,<=0.15.2")
     check_version("trl>=0.8.6,<=0.9.6")
-    if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
-        logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


 def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
@@ -182,8 +178,22 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor


+def get_current_memory() -> tuple[int, int]:
+    r"""Get the available and total memory for the current device (in Bytes)."""
+    if is_torch_xpu_available():
+        return torch.xpu.mem_get_info()
+    elif is_torch_npu_available():
+        return torch.npu.mem_get_info()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), torch.mps.recommended_max_memory()
+    elif is_torch_cuda_available():
+        return torch.cuda.mem_get_info()
+    else:
+        return 0, -1
+
+
 def get_peak_memory() -> tuple[int, int]:
-    r"""Get the peak memory usage for the current device (in Bytes)."""
+    r"""Get the peak memory usage (allocated, reserved) for the current device (in Bytes)."""
     if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
     elif is_torch_npu_available():
@@ -193,7 +203,7 @@ def get_peak_memory() -> tuple[int, int]:
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
-        return 0, 0
+        return 0, -1


 def has_tokenized_data(path: "os.PathLike") -> bool:
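Note on the new get_current_memory helper above: it returns a (free, total) pair in bytes, with total == -1 as the sentinel for "no accelerator detected". A minimal sketch of how a caller might turn that pair into a human-readable reading; the helper below is hypothetical and not part of this commit, and it assumes get_current_memory is importable from llamafactory.extras.misc:

from llamafactory.extras.misc import get_current_memory


def format_device_memory() -> str:
    """Hypothetical helper: render current device memory as 'used / total GiB'."""
    free, total = get_current_memory()
    if total == -1:  # CPU-only environment, nothing meaningful to report
        return "n/a"
    used_gib = (total - free) / (1024**3)
    return f"{used_gib:.2f} / {total / (1024**3):.2f} GiB"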
@@ -259,26 +269,37 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         return model_args.model_name_or_path

     if use_modelscope():
-        check_version("modelscope>=1.11.0", mandatory=True)
+        check_version("modelscope>=1.14.0", mandatory=True)
         from modelscope import snapshot_download  # type: ignore
+        from modelscope.hub.api import HubApi  # type: ignore
+
+        if model_args.ms_hub_token:
+            api = HubApi()
+            api.login(model_args.ms_hub_token)

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
-        return snapshot_download(
-            model_args.model_name_or_path,
-            revision=revision,
-            cache_dir=model_args.cache_dir,
-        )
+        with WeakFileLock(os.path.abspath(os.path.expanduser("~/.cache/llamafactory/modelscope.lock"))):
+            model_path = snapshot_download(
+                model_args.model_name_or_path,
+                revision=revision,
+                cache_dir=model_args.cache_dir,
+            )
+        return model_path

     if use_openmind():
         check_version("openmind>=0.8.0", mandatory=True)
         from openmind.utils.hub import snapshot_download  # type: ignore

-        return snapshot_download(
-            model_args.model_name_or_path,
-            revision=model_args.model_revision,
-            cache_dir=model_args.cache_dir,
-        )
+        with WeakFileLock(os.path.abspath(os.path.expanduser("~/.cache/llamafactory/openmind.lock"))):
+            model_path = snapshot_download(
+                model_args.model_name_or_path,
+                revision=model_args.model_revision,
+                cache_dir=model_args.cache_dir,
+            )
+        return model_path


 def use_modelscope() -> bool:
     return is_env_enabled("USE_MODELSCOPE_HUB")
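For context on the hunk above: WeakFileLock from huggingface_hub.utils is a file-based lock context manager, used here so that concurrent web UI workers do not run snapshot_download for the same hub at the same time. A minimal standalone sketch of the same pattern; the lock path and the guarded call below are illustrative only, not taken from this repository:

import os

from huggingface_hub.utils import WeakFileLock

lock_path = os.path.abspath(os.path.expanduser("~/.cache/llamafactory/modelscope.lock"))
os.makedirs(os.path.dirname(lock_path), exist_ok=True)  # the lock file's parent directory must exist

with WeakFileLock(lock_path):
    # only one process enters this block at a time; other processes wait on the lock file
    model_path = "/path/to/downloaded/model"  # placeholder for the guarded snapshot_download(...) call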
@@ -305,5 +326,5 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
     r"""Fix proxy settings for gradio ui."""
     os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     if ipv6_enabled:
-        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
-            os.environ.pop(name, None)
+        os.environ.pop("http_proxy", None)
+        os.environ.pop("HTTP_PROXY", None)
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import json
 import os
 import sys
 from pathlib import Path
@@ -23,7 +22,6 @@ from typing import Any, Optional, Union

 import torch
 import transformers
-import yaml
 from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -62,11 +60,11 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
         if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
             override_config = OmegaConf.from_cli(sys.argv[2:])
-            dict_config = yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
+            dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
             return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
         elif sys.argv[1].endswith(".json"):
             override_config = OmegaConf.from_cli(sys.argv[2:])
-            dict_config = json.loads(Path(sys.argv[1]).absolute().read_text())
+            dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
             return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
         else:
             return sys.argv[1:]
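Replacing yaml.safe_load and json.loads with OmegaConf.load works for both branches because OmegaConf reads files with a YAML parser and YAML accepts JSON syntax; the CLI overrides still win during the merge. A small sketch of that precedence, with illustrative config values that are not taken from this repository:

from omegaconf import OmegaConf

file_config = OmegaConf.create({"learning_rate": 5e-5, "num_train_epochs": 3})  # stands in for OmegaConf.load(path)
cli_overrides = OmegaConf.from_cli(["num_train_epochs=5"])  # same shape as sys.argv[2:]

merged = OmegaConf.merge(file_config, cli_overrides)  # later arguments take precedence
assert OmegaConf.to_container(merged) == {"learning_rate": 5e-05, "num_train_epochs": 5}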
@@ -166,6 +164,9 @@ def _check_extra_dependencies(
     if finetuning_args.use_adam_mini:
         check_version("adam-mini", mandatory=True)

+    if finetuning_args.use_swanlab:
+        check_version("swanlab", mandatory=True)
+
     if finetuning_args.plot_loss:
         check_version("matplotlib", mandatory=True)
@@ -348,6 +349,9 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     # https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/trainer.py#L782
     training_args.label_names = training_args.label_names or ["labels"]

+    if "swanlab" in training_args.report_to and finetuning_args.use_swanlab:
+        training_args.report_to.remove("swanlab")
+
     if (
         training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
......
@@ -188,7 +188,7 @@ def _setup_lora_tuning(
     if adapter_to_resume is not None:  # resume lora training
         if model_args.use_unsloth:
-            model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
+            model = load_unsloth_peft_model(config, model_args, finetuning_args, is_trainable=is_trainable)
         else:
             model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
......
@@ -19,6 +19,7 @@ import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
     AutoModelForVision2Seq,
@@ -29,7 +30,6 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
-from ..extras.packages import is_transformers_version_greater_than
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
@@ -39,10 +39,6 @@ from .model_utils.valuehead import load_valuehead_params
 from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model

-if is_transformers_version_greater_than("4.46.0"):
-    from transformers import AutoModelForImageTextToText

 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
@@ -111,9 +107,8 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             **init_kwargs,
         )
     except Exception as e:
-        raise OSError("Failed to load processor.") from e
-
-    patch_processor(processor, tokenizer, model_args)
+        logger.info_rank0(f"Failed to load processor: {e}.")
+        processor = None

     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
@@ -121,6 +116,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         logger.debug("The loaded processor is not an instance of Processor. Dropping it.")
         processor = None

+    if processor is not None:
+        patch_processor(processor, tokenizer, model_args)
+
     return {"tokenizer": tokenizer, "processor": processor}
@@ -160,10 +158,7 @@ def load_model(
     else:
         if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
             load_class = AutoModelForVision2Seq
-        elif (
-            is_transformers_version_greater_than("4.46.0")
-            and type(config) in AutoModelForImageTextToText._model_mapping.keys()
-        ):  # image-text
+        elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
             load_class = AutoModelForImageTextToText
         elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
             load_class = AutoModelForSeq2SeqLM
......
@@ -80,12 +80,15 @@ def get_unsloth_peft_model(
 def load_unsloth_peft_model(
-    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+    config: "PretrainedConfig",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    is_trainable: bool,
 ) -> "PreTrainedModel":
     r"""Load peft model with unsloth. Used in both training and inference."""
     from unsloth import FastLanguageModel  # type: ignore

-    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
+    unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args, finetuning_args)
     try:
         if not is_trainable:
             unsloth_kwargs["use_gradient_checkpointing"] = False
......
@@ -49,7 +49,7 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
     try:
         vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
-        return torch.load(vhead_file, map_location="cpu")
+        return torch.load(vhead_file, map_location="cpu", weights_only=True)
     except Exception as err:
         err_text = str(err)
......
@@ -204,6 +204,23 @@ _register_composite_model(
 )


+_register_composite_model(
+    model_type="gemma3n",
+    vision_model_keys=["vision_tower", "audio_tower"],
+    lora_conflict_keys=["timm_model", "subsample_conv_projection"],
+)
+
+
+# copied from qwen2vl
+_register_composite_model(
+    model_type="glm4v",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+
 _register_composite_model(
     model_type="internvl",
 )
......
@@ -178,6 +178,9 @@ def patch_model(
         resize_embedding_layer(model, tokenizer)

     if is_trainable:
+        if getattr(model.config, "model_type", None) == "gemma3n":
+            setattr(model_args, "disable_gradient_checkpointing", True)
+
         prepare_model_for_training(model, model_args)
         autocast_projector_dtype(model, model_args)
         add_z3_leaf_module(model)
......
@@ -76,7 +76,7 @@ def fix_valuehead_checkpoint(
         state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
     else:
         path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+        state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)
         os.remove(path_to_checkpoint)

     decoder_state_dict, v_head_state_dict = {}, {}
......
@@ -77,14 +77,19 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:
         with open(_get_config_path(), encoding="utf-8") as f:
             return safe_load(f)
     except Exception:
-        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
+        return {"lang": None, "hub_name": None, "last_model": None, "path_dict": {}, "cache_dir": None}


-def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+def save_config(
+    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+) -> None:
     r"""Save user config."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
+    if hub_name:
+        user_config["hub_name"] = hub_name
+
     if model_name:
         user_config["last_model"] = model_name
@@ -247,7 +252,7 @@ def create_ds_config() -> None:
         "stage": 2,
         "allgather_partitions": True,
         "allgather_bucket_size": 5e8,
-        "overlap_comm": True,
+        "overlap_comm": False,
         "reduce_scatter": True,
         "reduce_bucket_size": 5e8,
         "contiguous_gradients": True,
@@ -262,7 +267,7 @@ def create_ds_config() -> None:
     ds_config["zero_optimization"] = {
         "stage": 3,
-        "overlap_comm": True,
+        "overlap_comm": False,
         "contiguous_gradients": True,
         "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
......
@@ -15,6 +15,7 @@
 from .chatbot import create_chat_box
 from .eval import create_eval_tab
 from .export import create_export_tab
+from .footer import create_footer
 from .infer import create_infer_tab
 from .top import create_top
 from .train import create_train_tab
@@ -24,6 +25,7 @@ __all__ = [
     "create_chat_box",
     "create_eval_tab",
     "create_export_tab",
+    "create_footer",
     "create_infer_tab",
     "create_top",
     "create_train_tab",
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import json
 from typing import TYPE_CHECKING
@@ -50,7 +51,14 @@ def create_chat_box(
 ) -> tuple["Component", "Component", dict[str, "Component"]]:
     lang = engine.manager.get_elem_by_id("top.lang")
     with gr.Column(visible=visible) as chat_box:
-        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
+        kwargs = {}
+        if "show_copy_button" in inspect.signature(gr.Chatbot.__init__).parameters:
+            kwargs["show_copy_button"] = True
+        if "resizable" in inspect.signature(gr.Chatbot.__init__).parameters:
+            kwargs["resizable"] = True
+
+        chatbot = gr.Chatbot(type="messages", **kwargs)
         messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...extras.misc import get_current_memory
from ...extras.packages import is_gradio_available


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component


def get_device_memory() -> "gr.Slider":
    free, total = get_current_memory()
    if total != -1:
        used = round((total - free) / (1024**3), 2)
        total = round(total / (1024**3), 2)
        return gr.Slider(minimum=0, maximum=total, value=used, step=0.01, visible=True)
    else:
        return gr.Slider(visible=False)


def create_footer() -> dict[str, "Component"]:
    with gr.Row():
        device_memory = gr.Slider(visible=False, interactive=False)
        timer = gr.Timer(value=5)

    timer.tick(get_device_memory, outputs=[device_memory], queue=False)
    return dict(device_memory=device_memory)
@@ -16,9 +16,10 @@ from typing import TYPE_CHECKING
 from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ...extras.misc import use_modelscope, use_openmind
 from ...extras.packages import is_gradio_available
 from ..common import save_config
-from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints
+from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints, switch_hub


 if is_gradio_available():
@@ -33,8 +34,10 @@ def create_top() -> dict[str, "Component"]:
     with gr.Row():
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
         available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-        model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
-        model_path = gr.Textbox(scale=3)
+        model_name = gr.Dropdown(choices=available_models, value=None, scale=2)
+        model_path = gr.Textbox(scale=2)
+        default_hub = "modelscope" if use_modelscope() else "openmind" if use_openmind() else "huggingface"
+        hub_name = gr.Dropdown(choices=["huggingface", "modelscope", "openmind"], value=default_hub, scale=2)

     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
@@ -50,18 +53,25 @@ def create_top() -> dict[str, "Component"]:
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     ).then(check_template, [lang, template])
-    model_name.input(save_config, inputs=[lang, model_name], queue=False)
-    model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
+    model_name.input(save_config, inputs=[lang, hub_name, model_name], queue=False)
+    model_path.input(save_config, inputs=[lang, hub_name, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
     checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
     quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
+    hub_name.change(switch_hub, inputs=[hub_name], queue=False).then(
+        get_model_info, [model_name], [model_path, template], queue=False
+    ).then(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False).then(
+        check_template, [lang, template]
+    )
+    hub_name.input(save_config, inputs=[lang, hub_name], queue=False)

     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
+        hub_name=hub_name,
         finetuning_type=finetuning_type,
         checkpoint_path=checkpoint_path,
         quantization_bit=quantization_bit,
......
@@ -38,6 +38,15 @@ if is_gradio_available():
     import gradio as gr


+def switch_hub(hub_name: str) -> None:
+    r"""Switch model hub.
+
+    Inputs: top.hub_name
+    """
+    os.environ["USE_MODELSCOPE_HUB"] = "1" if hub_name == "modelscope" else "0"
+    os.environ["USE_OPENMIND_HUB"] = "1" if hub_name == "openmind" else "0"
+
+
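The new switch_hub above only flips the environment flags that use_modelscope() / use_openmind() in extras.misc read back through is_env_enabled. A hedged sketch of that round trip, with the flag check re-implemented locally for illustration (the real is_env_enabled may accept more truthy values), assuming switch_hub from the hunk above is in scope:

import os


def _flag_enabled(name: str) -> bool:
    # illustrative stand-in for is_env_enabled()
    return os.getenv(name, "0").lower() in ("1", "true", "y")


switch_hub("modelscope")  # sets USE_MODELSCOPE_HUB=1 and USE_OPENMIND_HUB=0
assert _flag_enabled("USE_MODELSCOPE_HUB")
assert not _flag_enabled("USE_OPENMIND_HUB")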
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
     r"""Judge if the quantization is available in this finetuning type.
@@ -112,7 +121,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tup
     running_log_path = os.path.join(output_path, RUNNING_LOG)
     if os.path.isfile(running_log_path):
         with open(running_log_path, encoding="utf-8") as f:
-            running_log = f.read()[-20000:]  # avoid lengthy log
+            running_log = "```\n" + f.read()[-20000:] + "\n```\n"  # avoid lengthy log

     trainer_log_path = os.path.join(output_path, TRAINER_LOG)
     if os.path.isfile(trainer_log_path):
......
@@ -49,11 +49,13 @@ class Engine:
     def resume(self):
         r"""Get the initial value of gradio components and restores training status if necessary."""
         user_config = load_config() if not self.demo_mode else {}  # do not use config in demo mode
-        lang = user_config.get("lang", None) or "en"
+        lang = user_config.get("lang") or "en"
         init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}

         if not self.pure_chat:
             current_time = get_time()
+            hub_name = user_config.get("hub_name") or "huggingface"
+            init_dict["top.hub_name"] = {"value": hub_name}
             init_dict["train.current_time"] = {"value": current_time}
             init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
             init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
......
@@ -22,6 +22,7 @@ from .components import (
     create_chat_box,
     create_eval_tab,
     create_export_tab,
+    create_footer,
     create_infer_tab,
     create_top,
     create_train_tab,
@@ -38,15 +39,13 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
     engine = Engine(demo_mode=demo_mode, pure_chat=False)
     hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

-    with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Factory ({hostname})", css=CSS) as demo:
+        title = gr.HTML()
+        subtitle = gr.HTML()
         if demo_mode:
-            gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
-            gr.HTML(
-                '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
-                "LLaMA Factory</a> for details.</center></h3>"
-            )
             gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

+        engine.manager.add_elems("head", {"title": title, "subtitle": subtitle})
         engine.manager.add_elems("top", create_top())
         lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")
@@ -63,6 +62,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
         with gr.Tab("Export"):
             engine.manager.add_elems("export", create_export_tab(engine))

+        engine.manager.add_elems("footer", create_footer())
         demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
         lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
         lang.input(save_config, inputs=[lang], queue=False)
......
@@ -13,6 +13,55 @@
 # limitations under the License.

 LOCALES = {
+    "title": {
+        "en": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: Unified Efficient Fine-Tuning of 100+ LLMs</center></h1>",
+        },
+        "ru": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: Унифицированная эффективная тонкая настройка 100+ LLMs</center></h1>",
+        },
+        "zh": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 一站式大模型高效微调平台</center></h1>",
+        },
+        "ko": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 100+ LLMs를 위한 통합 효율적인 튜닝</center></h1>",
+        },
+        "ja": {
+            "value": "<h1><center>🦙🏭LLaMA Factory: 100+ LLMs の統合効率的なチューニング</center></h1>",
+        },
+    },
+    "subtitle": {
+        "en": {
+            "value": (
+                "<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub Page</a></center></h3>"
+            ),
+        },
+        "ru": {
+            "value": (
+                "<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "страницу GitHub</a></center></h3>"
+            ),
+        },
+        "zh": {
+            "value": (
+                "<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 主页</a></center></h3>"
+            ),
+        },
+        "ko": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 페이지</a>를 방문하세요.</center></h3>"
+            ),
+        },
+        "ja": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub ページ</a>にアクセスする</center></h3>"
+            ),
+        },
+    },
     "lang": {
         "en": {
             "label": "Language",
@@ -74,6 +123,28 @@ LOCALES = {
             "info": "事前学習済みモデルへのパス、または Hugging Face のモデル識別子。",
         },
     },
+    "hub_name": {
+        "en": {
+            "label": "Hub name",
+            "info": "Choose the model download source.",
+        },
+        "ru": {
+            "label": "Имя хаба",
+            "info": "Выберите источник загрузки модели.",
+        },
+        "zh": {
+            "label": "模型下载源",
+            "info": "选择模型下载源。(网络受限环境推荐使用 ModelScope)",
+        },
+        "ko": {
+            "label": "모델 다운로드 소스",
+            "info": "모델 다운로드 소스를 선택하세요.",
+        },
+        "ja": {
+            "label": "モデルダウンロードソース",
+            "info": "モデルをダウンロードするためのソースを選択してください。",
+        },
+    },
     "finetuning_type": {
         "en": {
             "label": "Finetuning method",
@@ -2849,6 +2920,28 @@ LOCALES = {
             "value": "エクスポート",
         },
     },
+    "device_memory": {
+        "en": {
+            "label": "Device memory",
+            "info": "Current memory usage of the device (GB).",
+        },
+        "ru": {
+            "label": "Память устройства",
+            "info": "Текущая память на устройстве (GB).",
+        },
+        "zh": {
+            "label": "设备显存",
+            "info": "当前设备的显存(GB)。",
+        },
+        "ko": {
+            "label": "디바이스 메모리",
+            "info": "지금 사용 중인 기기 메모리 (GB).",
+        },
+        "ja": {
+            "label": "デバイスメモリ",
+            "info": "現在のデバイスのメモリ(GB)。",
+        },
+    },
 }
......
@@ -16,14 +16,13 @@ import json
 import os
 from collections.abc import Generator
 from copy import deepcopy
-from subprocess import Popen, TimeoutExpired
+from subprocess import PIPE, Popen, TimeoutExpired
 from typing import TYPE_CHECKING, Any, Optional

-from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_npu_available

 from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_accelerator_available, torch_gc, use_ray
+from ..extras.misc import is_accelerator_available, torch_gc
 from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
@@ -114,7 +113,7 @@ class Runner:
             return ""

-    def _finalize(self, lang: str, finish_info: str) -> str:
+    def _finalize(self, lang: str, finish_info: str) -> None:
         r"""Clean the cached memory and resets the runner."""
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
         gr.Info(finish_info)
@@ -123,7 +122,6 @@ class Runner:
         self.running = False
         self.running_data = None
         torch_gc()
-        return finish_info

     def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
         r"""Build and validate the training arguments."""
@@ -314,11 +312,13 @@ class Runner:
             max_samples=int(get("eval.max_samples")),
             per_device_eval_batch_size=get("eval.batch_size"),
             predict_with_generate=True,
+            report_to="none",
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
             output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
             trust_remote_code=True,
+            ddp_timeout=180000000,
         )

         if get("eval.predict"):
@@ -375,7 +375,7 @@ class Runner:
             env["FORCE_TORCHRUN"] = "1"

         # NOTE: DO NOT USE shell=True to avoid security risk
-        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
+        self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env, stderr=PIPE, text=True)
         yield from self.monitor()

     def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
@@ -417,7 +417,8 @@ class Runner:
         swanlab_link = self.manager.get_elem_by_id("train.swanlab_link") if self.do_train else None

         running_log = ""
-        while self.trainer is not None:
+        return_code = -1
+        while return_code == -1:
             if self.aborted:
                 yield {
                     output_box: ALERTS["info_aborting"][lang],
@@ -436,27 +437,26 @@ class Runner:
                     return_dict[swanlab_link] = running_info["swanlab_link"]

                 yield return_dict

             try:
-                self.trainer.wait(2)
-                self.trainer = None
+                stderr = self.trainer.communicate(timeout=2)[1]
+                return_code = self.trainer.returncode
             except TimeoutExpired:
                 continue

-        if self.do_train:
-            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
-                finish_info = ALERTS["info_finished"][lang]
-            else:
-                finish_info = ALERTS["err_failed"][lang]
-        else:
-            if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
-                finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
-            else:
-                finish_info = ALERTS["err_failed"][lang]
-
-        return_dict = {
-            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
-            progress_bar: gr.Slider(visible=False),
-        }
+        if return_code == 0 or self.aborted:
+            finish_info = ALERTS["info_finished"][lang]
+            if self.do_train:
+                finish_log = ALERTS["info_finished"][lang] + "\n\n" + running_log
+            else:
+                finish_log = load_eval_results(os.path.join(output_path, "all_results.json")) + "\n\n" + running_log
+        else:
+            print(stderr)
+            finish_info = ALERTS["err_failed"][lang]
+            finish_log = ALERTS["err_failed"][lang] + f" Exit code: {return_code}\n\n```\n{stderr}\n```\n"
+
+        self._finalize(lang, finish_info)
+        return_dict = {output_box: finish_log, progress_bar: gr.Slider(visible=False)}
         yield return_dict

     def save_args(self, data):
......
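The rewritten monitor loop above depends on Popen.communicate(timeout=...) raising TimeoutExpired while the child is still running; retrying communicate after that exception is documented as safe and loses no output. A minimal standalone sketch of the same polling pattern with a placeholder command (the real runner launches "llamafactory-cli train ..." instead):

from subprocess import PIPE, Popen, TimeoutExpired

# placeholder child process that sleeps briefly and writes to stderr
proc = Popen(["python", "-c", "import sys, time; time.sleep(5); print('done', file=sys.stderr)"], stderr=PIPE, text=True)

return_code = -1
while return_code == -1:
    try:
        stderr = proc.communicate(timeout=2)[1]  # blocks up to 2 s, raises if the child is still alive
        return_code = proc.returncode
    except TimeoutExpired:
        # child still running: a real loop refreshes progress bars and log widgets here
        continue

print(f"exit code: {return_code}, captured stderr: {stderr!r}")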
@@ -110,8 +110,8 @@ def test_glm4_function_formatter():
 def test_glm4_tool_formatter():
     formatter = ToolFormatter(tool_format="glm4")
     assert formatter.apply(content=json.dumps(TOOLS)) == [
-        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
-        "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
+        "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的,"
+        "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具\n\n"
         f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n"
         "在调用上述函数时,请使用 Json 格式表示调用的参数。"
     ]
......