Commit 0722acf1 authored by chenych

Update 0604

parent c4ba4563
@@ -99,27 +99,29 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable or not model_args.moe_aux_loss_coef:
+        return
+
     model_type = getattr(config, "model_type", None)
-    if model_args.moe_aux_loss_coef is not None:
-        if model_type in [
-            "dbrx",
-            "granitemoe",
-            "jamba",
-            "jetmoe",
-            "llama4",
-            "mixtral",
-            "olmoe",
-            "phimoe",
-            "qwen2_moe",
-            "qwen3_moe",
-        ]:
-            setattr(config, "output_router_logits", is_trainable)
+    if model_type in [
+        "dbrx",
+        "granitemoe",
+        "jamba",
+        "jetmoe",
+        "llama4",
+        "mixtral",
+        "olmoe",
+        "phimoe",
+        "qwen2_moe",
+        "qwen3_moe",
+    ]:
+        setattr(config, "output_router_logits", True)

-        if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
-            setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
+    if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
+        setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)

-        elif model_type == "deepseek":
-            setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
+    elif model_type == "deepseek":
+        setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)

-        elif model_type == "jetmoe":
-            setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
+    elif model_type == "jetmoe":
+        setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
@@ -97,7 +97,7 @@ def configure_quantization(
         quant_method = quantization_config.get("quant_method", "")

         if quant_method == QuantizationMethod.GPTQ:
-            check_version("auto_gptq>=0.5.0", mandatory=True)
+            check_version("gptqmodel>=2.0.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
             quantization_config["use_exllama"] = False  # disable exllama
@@ -111,12 +111,12 @@ def configure_quantization(
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")

-    elif model_args.export_quantization_bit is not None:  # auto-gptq
+    elif model_args.export_quantization_bit is not None:  # gptqmodel
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
             raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")

-        check_version("optimum>=1.17.0", mandatory=True)
-        check_version("auto_gptq>=0.5.0", mandatory=True)
+        check_version("optimum>=1.24.0", mandatory=True)
+        check_version("gptqmodel>=2.0.0", mandatory=True)
         from accelerate.utils import get_max_memory

         if getattr(config, "model_type", None) == "chatglm":
@@ -142,7 +142,8 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+        model_args.compute_dtype = torch.float16  # force fp16 for gptqmodel
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")

     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BNB:
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
     if model_args.rope_scaling is None:
         return
@@ -40,30 +40,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return

-    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
-    if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
+    if hasattr(config, "max_position_embeddings"):
+        old_max_length = getattr(config, "max_position_embeddings", None)
+    else:
+        logger.warning_rank0("Cannot find the max position embeddings in the config.")
+        return
+
+    if model_args.model_max_length is not None:  # training
+        if model_args.model_max_length <= old_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
+
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )

-        current_max_length = getattr(config, "max_position_embeddings", None)
-        if (not current_max_length) or model_args.model_max_length <= current_max_length:
-            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
-            return
-
-        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-        setattr(config, "max_position_embeddings", model_args.model_max_length)
-        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-
-        if model_args.rope_scaling == RopeScaling.DYNAMIC:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == RopeScaling.LLAMA3:
-            rope_kwargs["original_max_position_embeddings"] = current_max_length
-            rope_kwargs["low_freq_factor"] = 1.0
-            rope_kwargs["high_freq_factor"] = 4.0
-    else:
-        rope_kwargs["factor"] = 2.0
+        rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
+    else:  # inference
+        rope_factor = 2.0
+
+    rope_kwargs = {
+        "rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling),  # handle enum
+        "factor": rope_factor,
+    }
+    setattr(config, "max_position_embeddings", old_max_length * rope_factor)
+    logger.info_rank0(f"Enlarge max model length from {old_max_length} to {old_max_length * rope_factor}.")
+
+    if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+    elif model_args.rope_scaling == RopeScaling.LLAMA3:
+        rope_kwargs["original_max_position_embeddings"] = old_max_length
+        rope_kwargs["low_freq_factor"] = 1.0
+        rope_kwargs["high_freq_factor"] = 4.0

     setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
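For illustration only (not part of the commit): with the new logic, a config reporting max_position_embeddings = 4096 and a requested model_max_length = 10000 gives rope_factor = ceil(10000 / 4096) = 3.0, so the max length is enlarged to 12288. A minimal standalone sketch, with both lengths as assumed example values:

    import math

    old_max_length = 4096      # assumed: config.max_position_embeddings
    model_max_length = 10000   # assumed: requested training length
    rope_factor = float(math.ceil(model_max_length / old_max_length))  # 3.0
    print(old_max_length * rope_factor)  # 12288.0, the enlarged max_position_embeddings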
@@ -24,6 +24,7 @@ import transformers.models
 from transformers.activations import ACT2FN

 from ...extras import logging
+from ...extras.packages import is_transformers_version_greater_than


 if TYPE_CHECKING:
@@ -281,7 +282,7 @@ _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
@@ -290,6 +291,6 @@ _register_composite_model(
     model_type="qwen2_5_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
@@ -85,8 +85,8 @@ def patch_processor(
     setattr(processor, "video_min_pixels", model_args.video_min_pixels)
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
     setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
+    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)


 def patch_config(
@@ -102,8 +102,8 @@ def patch_config(
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

-    configure_attn_implementation(config, model_args, is_trainable)
-    configure_rope(config, model_args, is_trainable)
+    configure_attn_implementation(config, model_args)
+    configure_rope(config, model_args)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
@@ -121,11 +121,11 @@ class CustomDPOTrainer(DPOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def get_batch_samples(self, *args, **kwargs):
@@ -34,7 +34,6 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, ge
 if TYPE_CHECKING:
-    import torch.utils.data
     from transformers import PreTrainedModel, ProcessorMixin

     from ...hparams import FinetuningArguments
@@ -119,12 +118,12 @@ class CustomKTOTrainer(KTOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return Trainer._get_train_sampler(self)
+        return Trainer._get_train_sampler(self, *args, **kwargs)

     @override
     def get_batch_samples(self, *args, **kwargs):
@@ -70,11 +70,11 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
@@ -77,12 +77,23 @@ def run_pt(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
-        try:
-            perplexity = math.exp(metrics["eval_loss"])
-        except OverflowError:
-            perplexity = float("inf")
-
-        metrics["perplexity"] = perplexity
+        if isinstance(dataset_module.get("eval_dataset"), dict):
+            for key in dataset_module["eval_dataset"].keys():
+                try:
+                    perplexity = math.exp(metrics[f"eval_{key}_loss"])
+                except OverflowError:
+                    perplexity = float("inf")
+
+                metrics[f"eval_{key}_perplexity"] = perplexity
+        else:
+            try:
+                perplexity = math.exp(metrics["eval_loss"])
+            except OverflowError:
+                perplexity = float("inf")
+
+            metrics["eval_perplexity"] = perplexity

         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
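For illustration (not part of the commit): perplexity is exp(loss), so the multi-dataset branch above reports one eval_{key}_perplexity per eval dataset. A toy sketch with hypothetical dataset names and loss values:

    import math

    metrics = {"eval_alpaca_loss": 2.0, "eval_wiki_loss": 1.0}  # assumed example values
    for key in ("alpaca", "wiki"):
        metrics[f"eval_{key}_perplexity"] = math.exp(metrics[f"eval_{key}_loss"])
    # eval_alpaca_perplexity ≈ 7.389, eval_wiki_perplexity ≈ 2.718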
@@ -78,11 +78,11 @@ class PairwiseTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(
@@ -92,11 +92,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
@@ -665,6 +665,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
         mode=finetuning_args.swanlab_mode,
         config={"Framework": "🦙LlamaFactory"},
         logdir=finetuning_args.swanlab_logdir,
+        tags=["🦙LlamaFactory"],
     )

     return swanlab_callback
@@ -15,6 +15,7 @@
 import json
 import os
 from collections.abc import Generator
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional

 from transformers.utils import is_torch_npu_available
@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tup
     )


+@contextmanager
+def update_attr(obj: Any, name: str, value: Any):
+    old_value = getattr(obj, name, None)
+    setattr(obj, name, value)
+    yield
+    setattr(obj, name, old_value)
+
+
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
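For illustration (not part of the commit), a minimal usage sketch of the update_attr context manager added above, assuming it is in scope; the Dummy class is hypothetical:

    class Dummy:
        enable_thinking = False

    obj = Dummy()
    with update_attr(obj, "enable_thinking", True):
        print(obj.enable_thinking)  # True while inside the block
    print(obj.enable_thinking)      # restored to False after exit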
@@ -105,6 +114,11 @@ class WebChatModel(ChatModel):
         elif self.demo_mode:
             error = ALERTS["err_demo"][lang]

+        try:
+            json.loads(get("infer.extra_args"))
+        except json.JSONDecodeError:
+            error = ALERTS["err_json_schema"][lang]
+
         if error:
             gr.Warning(error)
             yield error
@@ -122,9 +136,9 @@ class WebChatModel(ChatModel):
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
-            vllm_enforce_eager=True,
             trust_remote_code=True,
         )
+        args.update(json.loads(get("infer.extra_args")))

         # checkpoints
         if checkpoint_path:
@@ -191,40 +205,42 @@ class WebChatModel(ChatModel):
         temperature: float,
         skip_special_tokens: bool,
         escape_html: bool,
+        enable_thinking: bool,
     ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
         r"""Generate output text in stream.

         Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
         Output: infer.chatbot, infer.messages
         """
-        chatbot.append({"role": "assistant", "content": ""})
-        response = ""
-        for new_text in self.stream_chat(
-            messages,
-            system,
-            tools,
-            images=[image] if image else None,
-            videos=[video] if video else None,
-            audios=[audio] if audio else None,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            temperature=temperature,
-            skip_special_tokens=skip_special_tokens,
-        ):
-            response += new_text
-            if tools:
-                result = self.engine.template.extract_tool(response)
-            else:
-                result = response
-
-            if isinstance(result, list):
-                tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
-                tool_calls = json.dumps(tool_calls, ensure_ascii=False)
-                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
-                bot_text = "```json\n" + tool_calls + "\n```"
-            else:
-                output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
-                bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
-
-            chatbot[-1] = {"role": "assistant", "content": bot_text}
-            yield chatbot, output_messages
+        with update_attr(self.engine.template, "enable_thinking", enable_thinking):
+            chatbot.append({"role": "assistant", "content": ""})
+            response = ""
+            for new_text in self.stream_chat(
+                messages,
+                system,
+                tools,
+                images=[image] if image else None,
+                videos=[video] if video else None,
+                audios=[audio] if audio else None,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                temperature=temperature,
+                skip_special_tokens=skip_special_tokens,
+            ):
+                response += new_text
+                if tools:
+                    result = self.engine.template.extract_tool(response)
+                else:
+                    result = response
+
+                if isinstance(result, list):
+                    tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
+                    tool_calls = json.dumps(tool_calls, ensure_ascii=False)
+                    output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
+                    bot_text = "```json\n" + tool_calls + "\n```"
+                else:
+                    output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
+                    bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)

+                chatbot[-1] = {"role": "assistant", "content": bot_text}
+                yield chatbot, output_messages
@@ -163,7 +163,14 @@ def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
 def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
     r"""Remove args with NoneType or False or empty string value."""
-    no_skip_keys = ["packing"]
+    no_skip_keys = [
+        "packing",
+        "enable_thinking",
+        "use_reentrant_gc",
+        "double_quantization",
+        "freeze_vision_tower",
+        "freeze_multi_modal_projector",
+    ]
     return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
@@ -205,6 +212,14 @@ def load_eval_results(path: os.PathLike) -> str:
     return f"```json\n{result}\n```\n"


+def calculate_pixels(pixels: str) -> int:
+    r"""Calculate the number of pixels from the expression."""
+    if "*" in pixels:
+        return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
+    else:
+        return int(pixels)
+
+
 def create_ds_config() -> None:
     r"""Create deepspeed config in the current directory."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
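For illustration (not part of the commit), the new calculate_pixels helper accepts either a product expression like the "768*768" default used in the train tab or a plain integer string, assuming it is in scope:

    print(calculate_pixels("768*768"))  # 589824
    print(calculate_pixels("1024"))     # 1024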
@@ -79,6 +79,7 @@ def create_chat_box(
             temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
             skip_special_tokens = gr.Checkbox(value=True)
             escape_html = gr.Checkbox(value=True)
+            enable_thinking = gr.Checkbox(value=True)
             clear_btn = gr.Button()

     tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
@@ -103,6 +104,7 @@ def create_chat_box(
             temperature,
             skip_special_tokens,
             escape_html,
+            enable_thinking,
         ],
         [chatbot, messages],
     )
@@ -127,6 +129,7 @@ def create_chat_box(
             temperature=temperature,
             skip_special_tokens=skip_special_tokens,
             escape_html=escape_html,
+            enable_thinking=enable_thinking,
             clear_btn=clear_btn,
         ),
     )
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 from collections.abc import Generator
 from typing import TYPE_CHECKING, Union
@@ -57,6 +58,7 @@ def save_model(
     export_legacy_format: bool,
     export_dir: str,
     export_hub_model_id: str,
+    extra_args: str,
 ) -> Generator[str, None, None]:
     user_config = load_config()
     error = ""
@@ -73,6 +75,11 @@ def save_model(
     elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list):
         error = ALERTS["err_gptq_lora"][lang]

+    try:
+        json.loads(extra_args)
+    except json.JSONDecodeError:
+        error = ALERTS["err_json_schema"][lang]
+
     if error:
         gr.Warning(error)
         yield error
@@ -92,6 +99,7 @@ def save_model(
         export_legacy_format=export_legacy_format,
         trust_remote_code=True,
     )
+    args.update(json.loads(extra_args))

     if checkpoint_path:
         if finetuning_type in PEFT_METHODS:  # list
@@ -118,6 +126,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
     with gr.Row():
         export_dir = gr.Textbox()
         export_hub_model_id = gr.Textbox()
+        extra_args = gr.Textbox(value="{}")

     checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
     checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)
@@ -141,6 +150,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
             export_legacy_format,
             export_dir,
             export_hub_model_id,
+            extra_args,
         ],
         [info_box],
     )
@@ -153,6 +163,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
         export_legacy_format=export_legacy_format,
         export_dir=export_dir,
         export_hub_model_id=export_hub_model_id,
+        extra_args=extra_args,
         export_btn=export_btn,
         info_box=info_box,
     )
@@ -36,6 +36,7 @@ def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
     with gr.Row():
         infer_backend = gr.Dropdown(choices=["huggingface", "vllm", "sglang"], value="huggingface")
         infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
+        extra_args = gr.Textbox(value='{"vllm_enforce_eager": true}')

     with gr.Row():
         load_btn = gr.Button()
@@ -43,11 +44,12 @@ def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
     info_box = gr.Textbox(show_label=False, interactive=False)

-    input_elems.update({infer_backend, infer_dtype})
+    input_elems.update({infer_backend, infer_dtype, extra_args})
     elem_dict.update(
         dict(
             infer_backend=infer_backend,
             infer_dtype=infer_dtype,
+            extra_args=extra_args,
             load_btn=load_btn,
             unload_btn=unload_btn,
             info_box=info_box,
@@ -18,7 +18,7 @@ from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
 from ..common import save_config
-from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints
+from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints


 if is_gradio_available():
@@ -49,7 +49,7 @@ def create_top() -> dict[str, "Component"]:
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
-    )
+    ).then(check_template, [lang, template])
     model_name.input(save_config, inputs=[lang, model_name], queue=False)
     model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
@@ -106,11 +106,11 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
                 use_llama_pro = gr.Checkbox()

             with gr.Column():
+                enable_thinking = gr.Checkbox(value=True)
                 report_to = gr.Dropdown(
-                    choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
-                    value=["none"],
+                    choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
+                    value="none",
                     allow_custom_value=True,
-                    multiselect=True,
                 )

     input_elems.update(
@@ -126,6 +126,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             mask_history,
             resize_vocab,
             use_llama_pro,
+            enable_thinking,
             report_to,
         }
     )
@@ -143,6 +144,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             mask_history=mask_history,
             resize_vocab=resize_vocab,
             use_llama_pro=use_llama_pro,
+            enable_thinking=enable_thinking,
             report_to=report_to,
         )
     )
@@ -231,6 +233,42 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
         )
     )

+    with gr.Accordion(open=False) as mm_tab:
+        with gr.Row():
+            freeze_vision_tower = gr.Checkbox(value=True)
+            freeze_multi_modal_projector = gr.Checkbox(value=True)
+            freeze_language_model = gr.Checkbox(value=False)
+
+        with gr.Row():
+            image_max_pixels = gr.Textbox(value="768*768")
+            image_min_pixels = gr.Textbox(value="32*32")
+            video_max_pixels = gr.Textbox(value="256*256")
+            video_min_pixels = gr.Textbox(value="16*16")
+
+    input_elems.update(
+        {
+            freeze_vision_tower,
+            freeze_multi_modal_projector,
+            freeze_language_model,
+            image_max_pixels,
+            image_min_pixels,
+            video_max_pixels,
+            video_min_pixels,
+        }
+    )
+    elem_dict.update(
+        dict(
+            mm_tab=mm_tab,
+            freeze_vision_tower=freeze_vision_tower,
+            freeze_multi_modal_projector=freeze_multi_modal_projector,
+            freeze_language_model=freeze_language_model,
+            image_max_pixels=image_max_pixels,
+            image_min_pixels=image_min_pixels,
+            video_max_pixels=video_max_pixels,
+            video_min_pixels=video_min_pixels,
+        )
+    )
+
     with gr.Accordion(open=False) as galore_tab:
         with gr.Row():
             use_galore = gr.Checkbox()
@@ -84,6 +84,17 @@ def get_model_info(model_name: str) -> tuple[str, str]:
     return get_model_path(model_name), get_template(model_name)


+def check_template(lang: str, template: str) -> None:
+    r"""Check if an instruct model is used.
+
+    Please use queue=True to show the warning message.
+
+    Inputs: top.lang, top.template
+    """
+    if template == "default":
+        gr.Warning(ALERTS["warn_no_instruct"][lang])
+
+
 def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
     r"""Get training infomation for monitor.