Commit ca625f43 authored by shihm's avatar shihm
Browse files

uodata

parent 7164651d
......@@ -43,10 +43,24 @@ if TYPE_CHECKING:
from ..hparams import ModelArguments
if is_transformers_version_greater_than("4.57.0"):
from transformers.models.qwen3_omni_moe import modeling_qwen3_omni_moe
logger = logging.get_logger(__name__)
def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
if is_transformers_version_greater_than("4.57.0") and not is_transformers_version_greater_than("4.58.0"):
from .model_utils.moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
logger.warning_rank0(
"You are using transformers with 4.x version, the Qwen3OmniMoeThinkerTextSparseMoeBlock will have some issues about deepspeed zero2 and fsdp2 training, so that we patched this model to avoid it. Transformers v5.0.0rc0 has fixed the issue, you can also try to update the transformers to using qwen3_omni. See more information on https://github.com/hiyouga/LLaMA-Factory/issues/9628."
)
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
......@@ -105,7 +119,7 @@ def patch_config(
configure_attn_implementation(config, model_args)
configure_rope(config, model_args)
configure_longlora(config, model_args, is_trainable)
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
configure_moe(config, model_args, is_trainable)
configure_visual_model(config)
configure_packing(model_args, is_trainable)
......@@ -136,19 +150,19 @@ def patch_config(
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
if getattr(config, "model_type", None) == "qwen3_omni_moe":
patch_qwen3_omni_moe_thinker_text_sparse_moe_block()
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
# do not cast data type of the model deepspeed zero3 without qlora
if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
init_kwargs["torch_dtype"] = model_args.compute_dtype
# fsdp/deepspeed zero3 does not need device map
if not (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) and init_kwargs["low_cpu_mem_usage"]:
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
def patch_model(
......@@ -175,7 +189,12 @@ def patch_model(
prepare_valuehead_model(model)
if model_args.resize_vocab:
resize_embedding_layer(model, tokenizer)
resize_embedding_layer(
model,
tokenizer,
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
init_special_tokens=model_args.init_special_tokens,
)
if is_trainable:
if getattr(model.config, "model_type", None) == "gemma3n":
......@@ -211,9 +230,23 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PeftModel):
self.pretrained_model.create_or_update_model_card(output_dir)
def get_rope_index_func(self: "AutoModelForCausalLMWithValueHead"):
if isinstance(self.pretrained_model, PeftModel):
base_model = self.pretrained_model.base_model.model
else:
base_model = self.pretrained_model
if base_model and hasattr(base_model, "get_rope_index"):
return base_model.get_rope_index
elif base_model and hasattr(base_model, "model") and hasattr(base_model.model, "get_rope_index"):
return base_model.model.get_rope_index
else:
return None
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
setattr(model, "get_rope_index", get_rope_index_func(model))
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
......@@ -26,16 +26,13 @@ import transformers
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_safetensors_available,
)
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory, is_env_enabled, use_ray
from ..extras.packages import is_safetensors_available
if is_safetensors_available():
......@@ -73,7 +70,7 @@ def fix_valuehead_checkpoint(
if safe_serialization:
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key).clone() for key in f.keys()}
else:
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu", weights_only=True)
......
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import torch
from ktransformers.sft.lora import KTrainer # type: ignore
from typing_extensions import override
from ..trainer_utils import get_batch_logps, nested_detach
from .trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import PreTrainedModel
class KDPOTrainer(KTrainer, CustomDPOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
labels = batch.pop("labels") # dpo do not need compute loss in forward
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logits = all_logits.to("cpu")
labels = labels.to(all_logits.device)
all_logps, valid_length = get_batch_logps(
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
)
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
batch_size = batch["input_ids"].size(0) // 2
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
chosen_length, _ = valid_length.split(batch_size, dim=0)
if self.loss_type in ["ipo", "orpo", "simpo"]:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
else:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
......@@ -26,6 +26,7 @@ import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
......@@ -78,6 +79,7 @@ class CustomDPOTrainer(DPOTrainer):
self.beta = finetuning_args.pref_beta
self.loss_type = finetuning_args.pref_loss
self.ftx_gamma = finetuning_args.pref_ftx
self.bco_gemma = finetuning_args.pref_bco_weight
self.label_smoothing = finetuning_args.dpo_label_smoothing
self.simpo_gamma = finetuning_args.simpo_gamma
self.ld_alpha = finetuning_args.ld_alpha
......@@ -94,7 +96,7 @@ class CustomDPOTrainer(DPOTrainer):
if not (
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
......@@ -108,6 +110,11 @@ class CustomDPOTrainer(DPOTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
if self.bco_gemma >= 1e-6:
from trl.trainer import RunningMoments
self.running = RunningMoments(self.accelerator)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
......@@ -151,6 +158,25 @@ class CustomDPOTrainer(DPOTrainer):
simpo_loss = -F.logsigmoid(self.beta * logits)
return simpo_loss
def bco_loss(
self,
chosen_logps: "torch.Tensor",
rejected_logps: "torch.Tensor",
reference_chosen_logps: "torch.Tensor",
reference_rejected_logps: "torch.Tensor",
) -> "torch.Tensor":
chosen_logratios = chosen_logps - reference_chosen_logps
rejected_logratios = rejected_logps - reference_rejected_logps
chosen_rewards = self.beta * chosen_logratios
rejected_rewards = self.beta * rejected_logratios
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
self.running.update(rewards) # update baseline
delta = self.running.mean
bco_loss = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
-(self.beta * rejected_logratios - delta)
)
return bco_loss
def compute_preference_loss(
self,
policy_chosen_logps: "torch.Tensor",
......@@ -174,12 +200,18 @@ class CustomDPOTrainer(DPOTrainer):
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
if self.bco_gemma > 1e-6:
bco_losses = self.bco_loss(
policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
)
losses = (losses + bco_losses * self.bco_gemma) / (1.0 + self.bco_gemma) # re-weight W_p and W_q
return losses, chosen_rewards, rejected_rewards
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], is_ref_model: bool = False
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
) -> dict[str, "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities.
......@@ -187,9 +219,10 @@ class CustomDPOTrainer(DPOTrainer):
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
labels = batch.pop("labels") # dpo do not need compute loss in forward
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(
logits=all_logits, labels=batch["labels"], ld_alpha=(self.ld_alpha if not is_ref_model else None)
logits=all_logits, labels=labels, ld_alpha=(self.ld_alpha if not is_ref_model else None)
)
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
......@@ -198,11 +231,18 @@ class CustomDPOTrainer(DPOTrainer):
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
chosen_length, _ = valid_length.split(batch_size, dim=0)
if self.loss_type in ["ipo", "orpo", "simpo"]:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps
chosen_logps_avg = chosen_logps
else:
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
chosen_logps_avg = chosen_logps / chosen_length
return {
"chosen_logps": chosen_logps,
"rejected_logps": rejected_logps,
"chosen_logits": chosen_logits,
"rejected_logits": rejected_logits,
"chosen_logps_avg": chosen_logps_avg,
}
@override
def compute_reference_log_probs(
......@@ -220,9 +260,9 @@ class CustomDPOTrainer(DPOTrainer):
ref_context = nullcontext()
with torch.no_grad(), ref_context:
reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(
ref_model, batch, is_ref_model=True
)
ref_output = self.concatenated_forward(ref_model, batch, is_ref_model=True)
reference_chosen_logps = ref_output["chosen_logps"]
reference_rejected_logps = ref_output["rejected_logps"]
return reference_chosen_logps, reference_rejected_logps
......@@ -235,13 +275,13 @@ class CustomDPOTrainer(DPOTrainer):
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_chosen_logps_avg,
) = self.concatenated_forward(model, batch)
model_output = self.concatenated_forward(model, batch)
policy_chosen_logps = model_output["chosen_logps"]
policy_rejected_logps = model_output["rejected_logps"]
policy_chosen_logits = model_output["chosen_logits"]
policy_rejected_logits = model_output["rejected_logits"]
policy_chosen_logps_avg = model_output["chosen_logps_avg"]
reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
......
......@@ -24,7 +24,6 @@ from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import CustomDPOTrainer
if TYPE_CHECKING:
......@@ -63,6 +62,16 @@ def run_dpo(
else:
ref_model = None
if model_args.use_kt:
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
from .ktrainer import KDPOTrainer as CustomDPOTrainer
GLOBAL_CONFIG._config["mod"] = "sft"
else:
from .trainer import CustomDPOTrainer
# Initialize our Trainer
trainer = CustomDPOTrainer(
model=model,
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional
from ..extras import logging
if TYPE_CHECKING:
from ..hparams import ModelArguments
logger = logging.get_logger(__name__)
def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
"""Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
Args:
model_args: Model arguments containing FP8 configuration
Returns:
List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
"""
if not model_args.fp8:
return []
try:
# Check if AORecipeKwargs is available (Accelerate 1.8.0+)
from accelerate.utils import AORecipeKwargs
backend = getattr(model_args, "fp8_backend", "auto")
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
# Create Float8LinearConfig if torchao backend is used
config = None
if backend == "torchao" or backend == "auto":
from torchao.float8 import Float8LinearConfig
# Use rowwise scaling for better performance (as recommended by torchao)
# Configure alignment requirements for FP8 kernels
config = Float8LinearConfig.from_recipe_name("rowwise")
# Enable alignment for better kernel performance
if hasattr(config, "enable_amax_init"):
config.enable_amax_init = True
if hasattr(config, "enable_pre_and_post_forward"):
config.enable_pre_and_post_forward = True
# Create module filter function to skip problematic layers
# TorchAO FP8 requires dimensions divisible by 16 for optimal kernels
def module_filter_func(module, layer_name):
# Skip embedding and output layers for numerical stability
skip_layers = ["embed", "lm_head", "output", "classifier"]
if any(skip_name in layer_name.lower() for skip_name in skip_layers):
return False
# Only convert Linear layers
if not (hasattr(module, "weight") and len(module.weight.shape) == 2):
return False
# Check dimension alignment for FP8 kernels
weight = module.weight
in_features, out_features = weight.shape[1], weight.shape[0]
# Skip layers with dimensions not divisible by 16 to avoid kernel errors
if in_features % 16 != 0 or out_features % 16 != 0:
logger.debug(
f"Skipping layer {layer_name} with dimensions {out_features}x{in_features} (not divisible by 16)"
)
return False
return True
# Map FSDP all-gather setting if available (this affects the underlying implementation)
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
logger.info_rank0("FSDP float8 all-gather optimization requested")
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
except Exception as e:
logger.info_rank0(f"Failed to create FP8 configuration: {e}")
return []
def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
"""Get the mixed precision setting for Accelerate when using FP8.
Args:
model_args: Model arguments containing FP8 configuration
Returns:
"fp8" if FP8 is enabled, None otherwise
"""
return "fp8" if model_args.fp8 else None
def configure_fp8_environment(model_args: "ModelArguments") -> None:
"""Configure FP8 environment for HuggingFace Accelerate.
FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
DeepSpeed or FSDP is used for distributed training. This function sets up the environment
variables and validates the FP8 configuration.
Args:
model_args: Model arguments containing FP8 configuration
"""
import os
if not model_args.fp8:
return
# Set mixed precision to fp8 for HuggingFace Accelerate
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
# Configure FP8 backend and options
backend = getattr(model_args, "fp8_backend", "auto")
if backend != "auto":
os.environ["FP8_BACKEND"] = backend
logger.info_rank0(f"Set FP8_BACKEND={backend}")
# Create and validate FP8 recipe kwargs (for logging/debugging)
fp8_kwargs = create_fp8_kwargs(model_args)
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
# Enable FSDP float8 all-gather optimization if requested
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
"""Verify that FP8 training is actually working after model preparation.
Args:
accelerator: The HuggingFace Accelerator instance
model_args: Model arguments containing FP8 configuration
"""
if not model_args.fp8:
return
# Check Accelerate's FP8 status
fp8_enabled = getattr(accelerator, "fp8_enabled", False)
fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
backend = getattr(model_args, "fp8_backend", "auto")
if backend == "torchao" or backend == "auto":
logger.info_rank0(
"FP8 training enabled with TorchAO backend. For optimal performance, "
"ensure model layer dimensions are mostly divisible by 16. "
"If you encounter issues, try fp8_backend='te' with Transformer Engine."
)
else:
logger.info_rank0(f"FP8 training enabled with {backend} backend.")
logger.info_rank0(f"Accelerate FP8 status - enabled: {fp8_enabled}, backend: {fp8_backend_type}")
if not fp8_enabled:
logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
......@@ -25,6 +25,7 @@ import torch
from transformers import Trainer
from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from trl.trainer.utils import prepare_deepspeed
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
......@@ -77,6 +78,13 @@ class CustomKTOTrainer(KTOTrainer):
self.desirable_weight = finetuning_args.kto_chosen_weight
self.undesirable_weight = finetuning_args.kto_rejected_weight
self.ftx_gamma = finetuning_args.pref_ftx
# trl
# Not all losses require a KL calculation
self.calculate_KL = True
if hasattr(self, "loss_type") and self.loss_type in ["apo_zero_unpaired"]:
self.calculate_KL = False
else:
self.loss_type = "kto"
Trainer.__init__(self, model=model, **kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
......@@ -90,7 +98,7 @@ class CustomKTOTrainer(KTOTrainer):
if not (
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_dpo, run_pt, run_sft
__all__ = ["run_dpo", "run_pt", "run_sft"]
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO override the original trainer
# Copyright 2025 the ROLL team and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional
from transformers import DataCollatorForSeq2Seq
from ...data import (
SFTDataCollatorWith4DAttentionMask,
get_dataset,
get_template_and_fix_tokenizer,
)
from ...data.collator import (
PairwiseDataCollatorWithPadding,
)
from ...extras.constants import IGNORE_INDEX, MCA_SUPPORTED_MODELS
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_mcore_adapter_available
from ...extras.ploting import plot_loss
from ...model import load_tokenizer
from ..callbacks import SaveProcessorCallback
if not is_mcore_adapter_available():
raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
from mcore_adapter.trainer import McaTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig
if TYPE_CHECKING:
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
from transformers import TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
logger = get_logger(__name__)
def _data_collator_wrapper(data_collator: Any):
@functools.wraps(data_collator)
def wrapper(features: Sequence[dict[str, Any]]):
labels_key = [k for k in features[0].keys() if k.endswith("labels")]
input_ids_key = [k for k in features[0].keys() if k.endswith("input_ids")]
for feature in features:
if len(labels_key) == 0: # pt
feature["labels"] = deepcopy(feature["input_ids"])[1:]
for k in labels_key:
feature[k] = feature[k][1:]
for k in input_ids_key:
feature[k] = feature[k][:-1]
for k in ["attention_mask", "position_ids"]:
if k in feature:
feature[k] = feature[k][:-1]
return data_collator(features)
return wrapper
def _check_model_support(model_args: "ModelArguments"):
from transformers import AutoConfig as HfAutoConfig
config = HfAutoConfig.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
if config.model_type not in MCA_SUPPORTED_MODELS:
raise ValueError(f"Model {config.model_type} is not supported by MCA.")
def run_pt(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "McaSeq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
# dataset needs +1 then cut back due to MCA shift logic
data_args.cutoff_len += 1
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
data_args.cutoff_len -= 1
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX,
)
data_collator = _data_collator_wrapper(data_collator)
trainer = McaTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
)
if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
if training_args.do_train:
train_result = trainer.train(training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "McaSeq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
# align packing flags
# TODO: FIX SequencePacking
data_args.neat_packing = training_args.sequence_packing = data_args.neat_packing or training_args.sequence_packing
data_args.packing = data_args.neat_packing or data_args.packing
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
# dataset needs +1 then cut back due to MCA shift logic
data_args.cutoff_len += 1
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
data_args.cutoff_len -= 1
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
# optional freezing for qwen2_vl, qwen2_5_vl
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
params_to_freeze = []
if finetuning_args.freeze_vision_tower:
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
if finetuning_args.freeze_multi_modal_projector:
params_to_freeze.extend(["multi_modal_projector"])
if finetuning_args.freeze_language_model:
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
if params_to_freeze:
for name, p in model.named_parameters():
if any(name.startswith(k) for k in params_to_freeze):
p.requires_grad_(False)
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
padding="max_length" if pad_to_max else "longest",
max_length=data_args.cutoff_len if pad_to_max else None,
pad_to_multiple_of=64,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
data_collator = _data_collator_wrapper(data_collator)
trainer = McaTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
)
if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
train_result = trainer.train(training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
def run_dpo(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "McaSeq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
if finetuning_args.use_ref_model:
ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
ref_model = AutoModel.from_config(ref_config)
ref_model.load_state_dict(model.state_dict())
else:
ref_model = None
# dataset needs +1 then cut back due to MCA shift logic
data_args.cutoff_len += 1
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
data_args.cutoff_len -= 1
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
dpo_config = DPOConfig(
beta=finetuning_args.pref_beta,
pref_loss=finetuning_args.pref_loss,
label_smoothing=finetuning_args.dpo_label_smoothing,
)
data_collator = PairwiseDataCollatorWithPadding(
template=template,
pad_to_multiple_of=64,
padding="max_length" if pad_to_max else "longest",
max_length=data_args.cutoff_len if pad_to_max else None,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
data_collator = _data_collator_wrapper(data_collator)
trainer = McaDPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_config=dpo_config,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
)
if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
train_result = trainer.train(training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="rm"
)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss", "rewards/accuracies"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
......@@ -33,12 +33,12 @@ from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl import __version__ as trl_version
from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override
from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor, torch_gc
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
......@@ -83,6 +83,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if eval_dataset is not None:
raise NotImplementedError("PPOTrainer does not support eval dataset yet.")
# Check if TRL version is compatible (0.8.6 <= version <= 0.9.6)
try:
from transformers.utils.versions import require_version
require_version(
"trl>=0.8.6,<=0.9.6",
"Incompatible TRL version detected. LLaMA-Factory ppo requires TRL version >=0.8.6,<=0.9.6. "
f"Found version {trl_version}. Please install the correct version with: `pip install trl>=0.8.6,<=0.9.6`\n"
"To fix: run `DISABLE_VERSION_CHECK=1 llamafactory-cli train example_ppo.yaml`\n",
)
except ImportError as e:
raise e
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
......@@ -390,7 +403,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch: dict[str, torch.Tensor] = self.prepare_model_inputs(queries, responses)
unwrapped_model: AutoModelForCausalLMWithValueHead = self.accelerator.unwrap_model(self.model)
if self.finetuning_args.reward_model_type == "lora":
if self.finetuning_args.reward_model_type in ["lora", "oft"]:
replace_model(unwrapped_model, target="reward")
reward_model = self.model
else:
......@@ -399,14 +412,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
values: torch.Tensor = reward_model(**batch, return_dict=True, use_cache=False)[-1]
if self.finetuning_args.reward_model_type == "lora":
if self.finetuning_args.reward_model_type in ["lora", "oft"]:
replace_model(unwrapped_model, target="default")
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return rewards.float().detach() # use fp32 type
@override
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",
......@@ -420,6 +432,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
Subclass and override to inject custom behavior.
"""
from trl.core import logprobs_from_logits
torch_gc()
bs = len(queries)
fbs = self.config.mini_batch_size
all_logprobs = []
......
......@@ -21,21 +21,29 @@ from typing_extensions import override
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import ProcessorMixin
from ...hparams import FinetuningArguments
from ...hparams import FinetuningArguments, ModelArguments
class CustomTrainer(Trainer):
r"""Inherit Trainer for custom optimizer."""
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
self,
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
model_args: Optional["ModelArguments"] = None,
**kwargs,
) -> None:
# Configure FP8 environment if enabled
if model_args is not None and model_args.fp8:
configure_fp8_environment(model_args)
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
......@@ -56,6 +64,10 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
# Verify FP8 status after trainer initialization (accelerator should be available)
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
verify_fp8_status(self.accelerator, model_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
import numpy as np
......@@ -28,7 +28,7 @@ if TYPE_CHECKING:
class ComputeAccuracy:
r"""Compute reward accuracy and support `batch_eval_metrics`."""
def _dump(self) -> Optional[dict[str, float]]:
def _dump(self) -> dict[str, float] | None:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
......@@ -39,7 +39,7 @@ class ComputeAccuracy:
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> dict[str, float] | None:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
if not chosen_scores.shape:
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
......
......@@ -21,11 +21,11 @@ from typing import TYPE_CHECKING, Optional
import numpy as np
import torch
from transformers.utils import is_jieba_available, is_nltk_available
from transformers.utils import is_nltk_available
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import numpify
from ...extras.packages import is_rouge_available
from ...extras.packages import is_jieba_available, is_rouge_available
if TYPE_CHECKING:
......
......@@ -29,6 +29,7 @@ from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
......@@ -37,7 +38,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments
from ...hparams import FinetuningArguments, ModelArguments
logger = logging.get_logger(__name__)
......@@ -50,9 +51,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self,
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
model_args: Optional["ModelArguments"] = None,
gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
# Configure FP8 environment if enabled
if model_args is not None and model_args.fp8:
configure_fp8_environment(model_args)
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
......@@ -78,6 +83,15 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
if finetuning_args.use_dft_loss:
from ..trainer_utils import dft_loss_func
self.compute_loss_func = dft_loss_func
# Verify FP8 status after trainer initialization (accelerator should be available)
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
verify_fp8_status(self.accelerator, model_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
......
......@@ -21,6 +21,7 @@ from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_templat
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
......@@ -67,6 +68,12 @@ def run_sft(
# Metric utils
metric_module = {}
if model_args.use_kt:
if training_args.predict_with_generate:
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
elif finetuning_args.compute_accuracy:
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
if training_args.predict_with_generate:
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
elif finetuning_args.compute_accuracy:
......@@ -75,21 +82,52 @@ def run_sft(
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
# Compatible with Transformers v4 and Transformers v5
if is_transformers_version_greater_than("4.58.0"):
extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
if not isinstance(extra_ids, list):
extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
string_tokens = [str(t) for t in extra_special_tokens]
extra_ids = tokenizer.convert_tokens_to_ids(string_tokens)
all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
unique_eos_ids = list(dict.fromkeys(all_eos_ids))
gen_kwargs["eos_token_id"] = unique_eos_ids
else:
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
# Initialize our Trainer
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
**dataset_module,
**tokenizer_module,
**metric_module,
)
if model_args.use_kt:
from ktransformers.sft.lora import KTrainer # type: ignore
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
GLOBAL_CONFIG._config["mod"] = "sft"
trainer = KTrainer(
model=model,
args=training_args,
tokenizer=tokenizer_module,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**metric_module,
)
trainer.model_accepts_loss_kwargs = False
model.config.use_cache = False
else:
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
**dataset_module,
**tokenizer_module,
**metric_module,
)
# Training
if training_args.do_train:
......
......@@ -20,7 +20,6 @@ from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from ..data import get_dataset, get_template_and_fix_tokenizer
from ..extras.misc import get_current_device
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
......@@ -81,17 +80,14 @@ def load_reference_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device()
if add_valuehead:
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=current_device
model_path, torch_dtype=torch.float16, device_map="auto"
)
if not is_trainable:
model.v_head = model.v_head.to(torch.float16)
return model
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
if use_lora or use_pissa:
model = PeftModel.from_pretrained(
model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
......
......@@ -19,9 +19,9 @@
import json
import os
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import torch
from transformers import Trainer
......@@ -100,12 +100,15 @@ def create_modelcard_and_push(
if model_args.use_unsloth:
kwargs["tags"] = kwargs["tags"] + ["unsloth"]
if model_args.use_kt:
kwargs["tags"] = kwargs["tags"] + ["ktransformers"]
if not training_args.do_train:
pass
elif training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub
Trainer.create_model_card(trainer, license="other", **kwargs) # prevent from connecting to hub
def create_ref_model(
......@@ -631,6 +634,51 @@ def get_batch_logps(
return logps, valid_length
def dft_loss_func(outputs, labels, num_items_in_batch=None):
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
logits = logits.float()
vocab_size = logits.size(-1)
labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
shift_labels = labels[..., 1:].contiguous()
logits = logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(logits.device)
loss = _dft_cross_entropy(logits, shift_labels, num_items_in_batch)
return loss
def _dft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
ignore_index: int = -100,
) -> torch.Tensor:
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
return torch.tensor(0.0, device=source.device, dtype=source.dtype)
valid_losses = per_token_loss[valid_mask]
with torch.no_grad():
target_probs = torch.exp(-valid_losses)
weighted_losses = valid_losses * target_probs
if num_items_in_batch is not None:
total_loss = weighted_losses.sum()
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(total_loss.device)
loss = total_loss / num_items_in_batch
else:
loss = weighted_losses.mean()
return loss
def nested_detach(
tensors: Union["torch.Tensor", list["torch.Tensor"], tuple["torch.Tensor"], dict[str, "torch.Tensor"]],
clone: bool = False,
......
# Copyright 2025 the LlamaFactory team.
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_ray_available
from ..extras.packages import is_mcore_adapter_available, is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
......@@ -66,7 +66,23 @@ def _training_function(config: dict[str, Any]) -> None:
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "pt":
if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
if not is_mcore_adapter_available():
raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
if finetuning_args.stage == "pt":
from .mca import run_pt as run_pt_mca
run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
from .mca import run_sft as run_sft_mca
run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "dpo":
from .mca import run_dpo as run_dpo_mca
run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment