Commit 0722acf1 authored by chenych

Update 0604

parent c4ba4563
@@ -99,8 +99,10 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.moe_aux_loss_coef:
return
model_type = getattr(config, "model_type", None)
if model_args.moe_aux_loss_coef is not None:
if model_type in [
"dbrx",
"granitemoe",
@@ -113,7 +115,7 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
"qwen2_moe",
"qwen3_moe",
]:
setattr(config, "output_router_logits", is_trainable)
setattr(config, "output_router_logits", True)
if model_type in ["granitemoe", "jamba", "llama4", "mixtral", "olmoe", "phimoe", "qwen2_moe", "qwen3_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
......
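Note: the early return added above is what allows `output_router_logits` to be hard-coded to `True`. A condensed sketch of the resulting control flow (argument names follow the hunk; the full `model_type` list is elided):

```python
# Condensed sketch of the reworked configure_moe (illustrative, not the full list).
def configure_moe_sketch(config, model_args, is_trainable) -> None:
    if not is_trainable or not model_args.moe_aux_loss_coef:
        return  # nothing to do at inference time or without an aux-loss weight

    model_type = getattr(config, "model_type", None)
    if model_type in ["qwen2_moe", "qwen3_moe"]:  # full list elided, see hunk above
        # is_trainable is already known to be True here, hence the constant.
        setattr(config, "output_router_logits", True)
        setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
```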
@@ -97,7 +97,7 @@ def configure_quantization(
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
check_version("auto_gptq>=0.5.0", mandatory=True)
check_version("gptqmodel>=2.0.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
@@ -111,12 +111,12 @@ def configure_quantization(
quant_bits = quantization_config.get("bits", "?")
logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq
elif model_args.export_quantization_bit is not None: # gptqmodel
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
check_version("optimum>=1.17.0", mandatory=True)
check_version("auto_gptq>=0.5.0", mandatory=True)
check_version("optimum>=1.24.0", mandatory=True)
check_version("gptqmodel>=2.0.0", mandatory=True)
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
@@ -142,7 +142,8 @@ def configure_quantization(
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
model_args.compute_dtype = torch.float16 # force fp16 for gptqmodel
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with GPTQModel.")
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BNB:
......
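Note: GPTQ export switches from auto_gptq to gptqmodel (with optimum>=1.24.0) and now forces fp16. A small illustrative sketch of the init kwargs set up in this path, reusing the names from the hunk (not a drop-in replacement for configure_quantization):

```python
import torch
from accelerate.utils import get_max_memory  # same helper the hunk imports

init_kwargs = {}
init_kwargs["device_map"] = "auto"            # let accelerate place the shards
init_kwargs["max_memory"] = get_max_memory()  # per-device memory budget
compute_dtype = torch.float16                 # GPTQModel quantization is forced to fp16
```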
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
if model_args.rope_scaling is None:
return
@@ -40,30 +40,40 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
logger.warning_rank0("Current model does not support RoPE scaling.")
return
rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)} # handle enum
if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
if hasattr(config, "max_position_embeddings"):
old_max_length = getattr(config, "max_position_embeddings", None)
else:
logger.warning_rank0("Cannot find the max position embeddings in the config.")
return
if model_args.model_max_length is not None: # training
if model_args.model_max_length <= old_max_length:
logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
return
if model_args.rope_scaling == RopeScaling.DYNAMIC:
logger.warning_rank0(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if (not current_max_length) or model_args.model_max_length <= current_max_length:
logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
return
rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
else: # inference
rope_factor = 2.0
logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
setattr(config, "max_position_embeddings", model_args.model_max_length)
rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
if model_args.rope_scaling == RopeScaling.DYNAMIC:
rope_kwargs["original_max_position_embeddings"] = current_max_length
rope_kwargs = {
"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling), # handle enum
"factor": rope_factor,
}
setattr(config, "max_position_embeddings", old_max_length * rope_factor)
logger.info_rank0(f"Enlarge max model length from {old_max_length} to {old_max_length * rope_factor}.")
if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
rope_kwargs["original_max_position_embeddings"] = old_max_length
elif model_args.rope_scaling == RopeScaling.LLAMA3:
rope_kwargs["original_max_position_embeddings"] = current_max_length
rope_kwargs["original_max_position_embeddings"] = old_max_length
rope_kwargs["low_freq_factor"] = 1.0
rope_kwargs["high_freq_factor"] = 4.0
else:
rope_kwargs["factor"] = 2.0
setattr(config, "rope_scaling", rope_kwargs)
logger.info_rank0(
......
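A worked example of the new scaling-factor logic (numbers are invented; the `rope_type` value stands for any non-dynamic scaling method):

```python
# Illustrative run of the reworked configure_rope: a model with
# max_position_embeddings=8192 fine-tuned with model_max_length=32768.
import math

old_max_length = 8192
model_max_length = 32768  # training case; at inference rope_factor defaults to 2.0
rope_factor = float(math.ceil(model_max_length / old_max_length))  # 4.0

rope_kwargs = {"rope_type": "linear", "factor": rope_factor}
# config.max_position_embeddings becomes old_max_length * rope_factor == 32768.
# dynamic/yarn additionally store original_max_position_embeddings = old_max_length;
# llama3 also sets low_freq_factor=1.0 and high_freq_factor=4.0.
```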
@@ -24,6 +24,7 @@ import transformers.models
from transformers.activations import ACT2FN
from ...extras import logging
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
@@ -281,7 +282,7 @@ _register_composite_model(
model_type="qwen2_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["model", "lm_head"],
language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@@ -290,6 +291,6 @@ _register_composite_model(
model_type="qwen2_5_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["model", "lm_head"],
language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@@ -85,8 +85,8 @@ def patch_processor(
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
def patch_config(
@@ -102,8 +102,8 @@ def patch_config(
else:
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
configure_attn_implementation(config, model_args, is_trainable)
configure_rope(config, model_args, is_trainable)
configure_attn_implementation(config, model_args)
configure_rope(config, model_args)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
......
@@ -121,11 +121,11 @@ class CustomDPOTrainer(DPOTrainer):
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
return super()._get_train_sampler(*args, **kwargs)
@override
def get_batch_samples(self, *args, **kwargs):
......
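The `*args, **kwargs` change here (and in the KTO, pre-training, reward-model and SFT trainers below) keeps the override compatible with newer transformers releases, which may pass the training dataset into `_get_train_sampler`, while older releases call it with no arguments. A minimal sketch of the pattern, with an assumed `disable_shuffling` attribute standing in for `finetuning_args.disable_shuffling`:

```python
from typing import Optional

import torch.utils.data
from transformers import Trainer


class SketchTrainer(Trainer):
    disable_shuffling: bool = False  # stands in for finetuning_args.disable_shuffling

    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.disable_shuffling:
            return torch.utils.data.SequentialSampler(self.train_dataset)
        # Forward whatever the caller passed: nothing on older transformers,
        # the train dataset on newer ones.
        return super()._get_train_sampler(*args, **kwargs)
```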
@@ -34,7 +34,6 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, ge
if TYPE_CHECKING:
import torch.utils.data
from transformers import PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments
@@ -119,12 +118,12 @@ class CustomKTOTrainer(KTOTrainer):
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return Trainer._get_train_sampler(self)
return Trainer._get_train_sampler(self, *args, **kwargs)
@override
def get_batch_samples(self, *args, **kwargs):
......
@@ -70,11 +70,11 @@ class CustomTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
return super()._get_train_sampler(*args, **kwargs)
@override
def compute_loss(self, model, inputs, *args, **kwargs):
......
@@ -77,12 +77,23 @@ def run_pt(
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if isinstance(dataset_module.get("eval_dataset"), dict):
for key in dataset_module["eval_dataset"].keys():
try:
perplexity = math.exp(metrics[f"eval_{key}_loss"])
except OverflowError:
perplexity = float("inf")
metrics[f"eval_{key}_perplexity"] = perplexity
else:
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
metrics["eval_perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
......
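A worked example of the new per-dataset branch (dataset names and losses are invented; key names follow the hunk):

```python
import math

# With a dict of eval datasets, one perplexity is reported per dataset.
metrics = {"eval_wiki_loss": 2.0, "eval_c4_loss": 3.0}
for key in ["wiki", "c4"]:
    try:
        perplexity = math.exp(metrics[f"eval_{key}_loss"])
    except OverflowError:
        perplexity = float("inf")
    metrics[f"eval_{key}_perplexity"] = perplexity
# metrics now also holds eval_wiki_perplexity ≈ 7.39 and eval_c4_perplexity ≈ 20.09.
# The single-dataset case writes eval_perplexity rather than the old bare "perplexity" key.
```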
@@ -78,11 +78,11 @@ class PairwiseTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
return super()._get_train_sampler(*args, **kwargs)
@override
def compute_loss(
......
@@ -92,11 +92,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
return super()._get_train_sampler(*args, **kwargs)
@override
def compute_loss(self, model, inputs, *args, **kwargs):
......
@@ -665,6 +665,7 @@ def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCall
mode=finetuning_args.swanlab_mode,
config={"Framework": "🦙LlamaFactory"},
logdir=finetuning_args.swanlab_logdir,
tags=["🦙LlamaFactory"],
)
return swanlab_callback
......
@@ -15,6 +15,7 @@
import json
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional
from transformers.utils import is_torch_npu_available
@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tup
)
@contextmanager
def update_attr(obj: Any, name: str, value: Any):
old_value = getattr(obj, name, None)
setattr(obj, name, value)
yield
setattr(obj, name, old_value)
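A self-contained usage sketch of the `update_attr` helper added above (the `Template` class is a dummy stand-in):

```python
from contextlib import contextmanager
from typing import Any


@contextmanager
def update_attr(obj: Any, name: str, value: Any):
    old_value = getattr(obj, name, None)
    setattr(obj, name, value)
    yield  # note: without try/finally, the old value is not restored if the body raises
    setattr(obj, name, old_value)


class Template:  # dummy object for illustration
    enable_thinking = True


template = Template()
with update_attr(template, "enable_thinking", False):
    assert template.enable_thinking is False  # overridden for one generation
assert template.enable_thinking is True  # restored afterwards
```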
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
@@ -105,6 +114,11 @@ class WebChatModel(ChatModel):
elif self.demo_mode:
error = ALERTS["err_demo"][lang]
try:
json.loads(get("infer.extra_args"))
except json.JSONDecodeError:
error = ALERTS["err_json_schema"][lang]
if error:
gr.Warning(error)
yield error
@@ -122,9 +136,9 @@ class WebChatModel(ChatModel):
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
infer_backend=get("infer.infer_backend"),
infer_dtype=get("infer.infer_dtype"),
vllm_enforce_eager=True,
trust_remote_code=True,
)
args.update(json.loads(get("infer.extra_args")))
# checkpoints
if checkpoint_path:
@@ -191,12 +205,14 @@ class WebChatModel(ChatModel):
temperature: float,
skip_special_tokens: bool,
escape_html: bool,
enable_thinking: bool,
) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
r"""Generate output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages
"""
with update_attr(self.engine.template, "enable_thinking", enable_thinking):
chatbot.append({"role": "assistant", "content": ""})
response = ""
for new_text in self.stream_chat(
......
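The removed hard-coded `vllm_enforce_eager=True` is superseded by the new user-editable `infer.extra_args` JSON textbox (default `{"vllm_enforce_eager": true}`, see the infer-tab hunk further down): its content is validated and then merged into the load arguments. A condensed sketch of that flow (the argument subset is illustrative):

```python
import json

extra_args = '{"vllm_enforce_eager": true}'  # default value of the new textbox
try:
    overrides = json.loads(extra_args)
except json.JSONDecodeError:
    overrides = None  # the UI shows ALERTS["err_json_schema"] and aborts instead

args = {"infer_backend": "vllm", "trust_remote_code": True}  # illustrative subset
if overrides is not None:
    args.update(overrides)
```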
@@ -163,7 +163,14 @@ def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
r"""Remove args with NoneType or False or empty string value."""
no_skip_keys = ["packing"]
no_skip_keys = [
"packing",
"enable_thinking",
"use_reentrant_gc",
"double_quantization",
"freeze_vision_tower",
"freeze_multi_modal_projector",
]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
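A worked example of the widened `no_skip_keys` filter (keys from the list above; the other argument values are invented):

```python
no_skip_keys = [
    "packing",
    "enable_thinking",
    "use_reentrant_gc",
    "double_quantization",
    "freeze_vision_tower",
    "freeze_multi_modal_projector",
]
args = {"packing": False, "freeze_vision_tower": False, "lora_rank": 8, "adapter_name_or_path": "", "max_samples": None}
cleaned = {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
# cleaned == {"packing": False, "freeze_vision_tower": False, "lora_rank": 8}
# i.e. False-valued flags listed in no_skip_keys survive, other falsy values are dropped.
```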
@@ -205,6 +212,14 @@ def load_eval_results(path: os.PathLike) -> str:
return f"```json\n{result}\n```\n"
def calculate_pixels(pixels: str) -> int:
r"""Calculate the number of pixels from the expression."""
if "*" in pixels:
return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
else:
return int(pixels)
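A quick behaviour check for `calculate_pixels` (the function is repeated from the hunk above so the snippet runs standalone; the expressions match the defaults in the new multimodal tab):

```python
def calculate_pixels(pixels: str) -> int:  # repeated from the hunk above
    if "*" in pixels:
        return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
    else:
        return int(pixels)


assert calculate_pixels("768*768") == 589824  # default image_max_pixels
assert calculate_pixels("16*16") == 256       # default video_min_pixels
assert calculate_pixels("1024") == 1024       # plain integers pass through unchanged
```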
def create_ds_config() -> None:
r"""Create deepspeed config in the current directory."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
......
@@ -79,6 +79,7 @@ def create_chat_box(
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
skip_special_tokens = gr.Checkbox(value=True)
escape_html = gr.Checkbox(value=True)
enable_thinking = gr.Checkbox(value=True)
clear_btn = gr.Button()
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
@@ -103,6 +104,7 @@ def create_chat_box(
temperature,
skip_special_tokens,
escape_html,
enable_thinking,
],
[chatbot, messages],
)
@@ -127,6 +129,7 @@ def create_chat_box(
temperature=temperature,
skip_special_tokens=skip_special_tokens,
escape_html=escape_html,
enable_thinking=enable_thinking,
clear_btn=clear_btn,
),
)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Union
@@ -57,6 +58,7 @@ def save_model(
export_legacy_format: bool,
export_dir: str,
export_hub_model_id: str,
extra_args: str,
) -> Generator[str, None, None]:
user_config = load_config()
error = ""
@@ -73,6 +75,11 @@ def save_model(
elif export_quantization_bit in GPTQ_BITS and checkpoint_path and isinstance(checkpoint_path, list):
error = ALERTS["err_gptq_lora"][lang]
try:
json.loads(extra_args)
except json.JSONDecodeError:
error = ALERTS["err_json_schema"][lang]
if error:
gr.Warning(error)
yield error
@@ -92,6 +99,7 @@ def save_model(
export_legacy_format=export_legacy_format,
trust_remote_code=True,
)
args.update(json.loads(extra_args))
if checkpoint_path:
if finetuning_type in PEFT_METHODS: # list
@@ -118,6 +126,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
with gr.Row():
export_dir = gr.Textbox()
export_hub_model_id = gr.Textbox()
extra_args = gr.Textbox(value="{}")
checkpoint_path: gr.Dropdown = engine.manager.get_elem_by_id("top.checkpoint_path")
checkpoint_path.change(can_quantize, [checkpoint_path], [export_quantization_bit], queue=False)
@@ -141,6 +150,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
export_legacy_format,
export_dir,
export_hub_model_id,
extra_args,
],
[info_box],
)
@@ -153,6 +163,7 @@ def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
export_legacy_format=export_legacy_format,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id,
extra_args=extra_args,
export_btn=export_btn,
info_box=info_box,
)
@@ -36,6 +36,7 @@ def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
with gr.Row():
infer_backend = gr.Dropdown(choices=["huggingface", "vllm", "sglang"], value="huggingface")
infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
extra_args = gr.Textbox(value='{"vllm_enforce_eager": true}')
with gr.Row():
load_btn = gr.Button()
@@ -43,11 +44,12 @@ def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
info_box = gr.Textbox(show_label=False, interactive=False)
input_elems.update({infer_backend, infer_dtype})
input_elems.update({infer_backend, infer_dtype, extra_args})
elem_dict.update(
dict(
infer_backend=infer_backend,
infer_dtype=infer_dtype,
extra_args=extra_args,
load_btn=load_btn,
unload_btn=unload_btn,
info_box=info_box,
......
@@ -18,7 +18,7 @@ from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import save_config
from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints
from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints
if is_gradio_available():
@@ -49,7 +49,7 @@ def create_top() -> dict[str, "Component"]:
model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
)
).then(check_template, [lang, template])
model_name.input(save_config, inputs=[lang, model_name], queue=False)
model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
......
@@ -106,11 +106,11 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
use_llama_pro = gr.Checkbox()
with gr.Column():
enable_thinking = gr.Checkbox(value=True)
report_to = gr.Dropdown(
choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
value=["none"],
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
value="none",
allow_custom_value=True,
multiselect=True,
)
input_elems.update(
@@ -126,6 +126,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
mask_history,
resize_vocab,
use_llama_pro,
enable_thinking,
report_to,
}
)
@@ -143,6 +144,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
mask_history=mask_history,
resize_vocab=resize_vocab,
use_llama_pro=use_llama_pro,
enable_thinking=enable_thinking,
report_to=report_to,
)
)
@@ -231,6 +233,42 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
)
)
with gr.Accordion(open=False) as mm_tab:
with gr.Row():
freeze_vision_tower = gr.Checkbox(value=True)
freeze_multi_modal_projector = gr.Checkbox(value=True)
freeze_language_model = gr.Checkbox(value=False)
with gr.Row():
image_max_pixels = gr.Textbox(value="768*768")
image_min_pixels = gr.Textbox(value="32*32")
video_max_pixels = gr.Textbox(value="256*256")
video_min_pixels = gr.Textbox(value="16*16")
input_elems.update(
{
freeze_vision_tower,
freeze_multi_modal_projector,
freeze_language_model,
image_max_pixels,
image_min_pixels,
video_max_pixels,
video_min_pixels,
}
)
elem_dict.update(
dict(
mm_tab=mm_tab,
freeze_vision_tower=freeze_vision_tower,
freeze_multi_modal_projector=freeze_multi_modal_projector,
freeze_language_model=freeze_language_model,
image_max_pixels=image_max_pixels,
image_min_pixels=image_min_pixels,
video_max_pixels=video_max_pixels,
video_min_pixels=video_min_pixels,
)
)
with gr.Accordion(open=False) as galore_tab:
with gr.Row():
use_galore = gr.Checkbox()
......
@@ -84,6 +84,17 @@ def get_model_info(model_name: str) -> tuple[str, str]:
return get_model_path(model_name), get_template(model_name)
def check_template(lang: str, template: str) -> None:
r"""Check if an instruct model is used.
Please use queue=True to show the warning message.
Inputs: top.lang, top.template
"""
if template == "default":
gr.Warning(ALERTS["warn_no_instruct"][lang])
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
r"""Get training infomation for monitor.
......