Commit 7ea81099 authored by chenych

update llama4

parent 84987715
# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and PEFT library,
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
......@@ -21,7 +21,7 @@
import inspect
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import torch
......@@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
class UnslothGradientCheckpointing(torch.autograd.Function):
r"""
Saves VRAM by smartly offloading to RAM.
"""
r"""Saves VRAM by smartly offloading to RAM."""
@staticmethod
@torch.cuda.amp.custom_fwd
......@@ -77,13 +75,14 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
r"""
Only applies gradient checkpointing to trainable layers.
"""
r"""Only applies gradient checkpointing to trainable layers."""
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
if isinstance(func, partial):
module: torch.nn.Module = func.func.__self__
else:
module: torch.nn.Module = func.__self__
has_grad = False
if any(param.requires_grad for param in module.parameters()):
......@@ -103,11 +102,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
def _gradient_checkpointing_enable(
self: "PreTrainedModel",
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
use_unsloth_gc: bool = False,
) -> None:
r"""
Activates gradient checkpointing for the current model.
r"""Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
......@@ -134,17 +132,18 @@ def _gradient_checkpointing_enable(
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
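# helper used by prepare_model_for_training below (item 3 of its docstring) to upcast lm_head
# outputs to float32; presumably attached to the output layer via register_forward_hook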
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
r"""Prepare the model before training.
It includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32.
"""
if model_args.upcast_layernorm:
logger.info_rank0("Upcasting layernorm weights in float32.")
......
......@@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
r"""Resize token embeddings."""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
......
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...extras import logging
logger = logging.get_logger(__name__)
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable:
setattr(config, "use_cache", model_args.use_cache)
if hasattr(config, "text_config"):
setattr(config.text_config, "use_cache", model_args.use_cache)
if model_args.use_cache:
logger.info_rank0("KV cache is enabled for faster generation.")
else:
logger.info_rank0("KV cache is disabled.")
else:
setattr(config, "use_cache", False)
if hasattr(config, "text_config"):
setattr(config.text_config, "use_cache", False)
logger.info_rank0("KV cache is disabled during training.")
......@@ -27,39 +27,6 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def apply_liger_kernel_to_qwen2_5_vl(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
) -> None:
from liger_kernel.transformers import LigerCrossEntropyLoss, LigerRMSNorm, LigerSwiGLUMLP
from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
def get_dtype(self: "modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel"):
return self.dtype
modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel.get_dtype = get_dtype
if rope:
modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
if rms_norm:
modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
if cross_entropy:
modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_lce_forward
if swiglu:
modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
def apply_liger_kernel(
config: "PretrainedConfig",
model_args: "ModelArguments",
......@@ -74,6 +41,12 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
elif model_type == "gemma2":
from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
elif model_type == "gemma3":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "paligemma":
from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
elif model_type == "llama":
from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
elif model_type == "mistral":
......@@ -89,7 +62,7 @@ def apply_liger_kernel(
elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
elif model_type == "qwen2_5_vl":
apply_liger_kernel = apply_liger_kernel_to_qwen2_5_vl
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
else:
logger.warning_rank0("Current model does not support liger kernel.")
return
......
# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
# Copyright 2025 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
#
# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
......@@ -18,7 +18,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
......@@ -54,14 +54,14 @@ def llama_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
......@@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
......@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(
attn_output: torch.Tensor = _flash_attention_forward(
query_states,
key_states,
value_states,
......@@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
is_causal=self.is_causal,
)
else:
attn_output: "torch.Tensor" = self._flash_attention_forward(
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
......@@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
......@@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING
from ...extras import logging
from .visual import COMPOSITE_MODELS
......@@ -25,10 +25,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r"""
Finds all available modules to apply LoRA, GaLore or APOLLO.
"""
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}
if model_type == "chatglm":
......@@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
return list(module_names)
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
r"""
Finds the modules in the expanded blocks to apply lora.
"""
def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
r"""Find the modules in the expanded blocks to apply lora."""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
......@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None:
check_version("deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
......@@ -34,9 +34,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
if not is_deepspeed_zero3_enabled():
return
......
# Copyright 2024 Musab Gultekin and the LlamaFactory team.
# Copyright 2025 Musab Gultekin and the LlamaFactory team.
#
# This code is based on the Musab Gultekin's functionary library.
# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
......@@ -37,7 +37,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
......@@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""
Gets the sequnce lengths in the current batch.
r"""Get the sequnce lengths in the current batch.
e.g.
```python
......@@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask).item()
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
......@@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
return seqlens
def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
r"""
Prepares the indices and seqlens for flash attn varlen function.
def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
r"""Prepare the indices and seqlens for flash attn varlen function.
Returns:
indices: indices of non-masked tokens from the flattened sequence.
......@@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
[0, 2, 5, 6, 8, 11]
3
```
"""
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
......
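For context, a small worked example of what these packing helpers compute; the tensor values are illustrative and follow from the definitions above, they are not part of the diff:
import torch
# one batch row packing two sequences: the labels 1 and 2 mark the packed sequences, 0 is padding
attention_mask = torch.tensor([[1, 1, 2, 2, 2, 0]])
# get_seqlens_in_batch(attention_mask) -> tensor([2, 3])
# get_unpad_data(attention_mask) returns:
#   indices (non-masked positions in the flattened batch) -> tensor([0, 1, 2, 3, 4])
#   cu_seqlens (cumulative lengths with a leading zero)   -> tensor([0, 2, 5])
#   max_seqlen_in_batch                                   -> 3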
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and Optimum library.
# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
......@@ -19,7 +19,7 @@
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any
import torch
from datasets import load_dataset
......@@ -43,9 +43,7 @@ logger = logging.get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
......@@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum):
HQQ = "hqq"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r"""
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
"""
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
......@@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
n_try += 1
if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen
......@@ -101,11 +97,9 @@ def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
init_kwargs: dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None:
logger.warning_rank0("`quantization_bit` will not affect the PTQ-quantized models.")
......@@ -113,7 +107,7 @@ def configure_quantization(
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
......
# Copyright 2024 LMSYS and the LlamaFactory team.
# Copyright 2025 LMSYS and the LlamaFactory team.
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# This code is inspired by the LMSYS's FastChat library.
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
from ...extras.misc import get_current_device
......@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
) -> dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
......@@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
r"""
Optionally loads pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
r"""Optionally load pretrained model with unsloth. Used in training."""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try:
......@@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(
def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
) -> "PreTrainedModel":
r"""
Gets the peft model for the pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
r"""Get the peft model for the pretrained model with unsloth. Used in training."""
from unsloth import FastLanguageModel # type: ignore
unsloth_peft_kwargs = {
"model": model,
......@@ -82,10 +78,8 @@ def get_unsloth_peft_model(
def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Loads peft model with unsloth. Used in both training and inference.
"""
from unsloth import FastLanguageModel
r"""Load peft model with unsloth. Used in both training and inference."""
from unsloth import FastLanguageModel # type: ignore
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try:
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
import torch
from transformers.utils import cached_file
......@@ -30,9 +30,8 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
r"""Load value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
......@@ -16,7 +16,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
from typing import TYPE_CHECKING, Optional
import torch
import transformers
......@@ -27,7 +27,7 @@ from ...extras import logging
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from ...hparams import FinetuningArguments, ModelArguments
......@@ -40,9 +40,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
class CompositeModel:
model_type: str
projector_key: str
vision_model_keys: List[str]
language_model_keys: List[str]
lora_conflict_keys: List[str]
vision_model_keys: list[str]
language_model_keys: list[str]
lora_conflict_keys: list[str]
def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
for key in self.projector_key.split("."):
......@@ -51,16 +51,26 @@ class CompositeModel:
return module
COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
def _register_composite_model(
model_type: str,
projector_key: Optional[str] = None,
vision_model_keys: Optional[List[str]] = None,
language_model_keys: Optional[List[str]] = None,
lora_conflict_keys: Optional[List[str]] = None,
vision_model_keys: Optional[list[str]] = None,
language_model_keys: Optional[list[str]] = None,
lora_conflict_keys: Optional[list[str]] = None,
):
r"""Register a new composite model.
Args:
model_type: the model type.
projector_key: the key of the multimodal projector, defaults to `multi_modal_projector`.
vision_model_keys: the keys of the vision tower, defaults to `["vision_tower"]`.
language_model_keys: the keys of the language model, defaults to `["language_model"]`.
lora_conflict_keys: the keys that conflict with LoRA training, defaults to an empty list.
"""
COMPOSITE_MODELS[model_type] = CompositeModel(
model_type=model_type,
projector_key=projector_key or "multi_modal_projector",
......@@ -116,12 +126,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Casts projector output to half precision for fine-tuning quantized VLMs.
"""
r"""Cast projector output to half precision for fine-tuning quantized VLMs."""
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
......@@ -137,9 +145,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
def configure_visual_model(config: "PretrainedConfig") -> None:
r"""
Patches VLMs before loading them.
"""
r"""Patch VLMs before loading them."""
if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
# required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
......@@ -149,10 +155,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
r"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
r"""Freeze vision tower and language model for VLM full/freeze tuning."""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in COMPOSITE_MODELS:
......@@ -174,47 +178,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
return forbidden_modules
def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""
Computes the number of special tokens per image.
"""
model_type = getattr(config, "model_type", None)
if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token
image_seqlen += 1
elif model_type == "paligemma":
image_seqlen = config.vision_config.num_image_tokens
else:
image_seqlen = -1
return image_seqlen
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Computes the patch size of the vit.
"""
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Get the vision_feature_select_strategy.
"""
vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
)
return vision_feature_select_strategy
def patch_target_modules(
model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> List[str]:
r"""
Freezes vision tower for VLM LoRA tuning.
"""
model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: list[str]
) -> list[str]:
r"""Freeze vision tower for VLM LoRA tuning."""
model_type = getattr(model.config, "model_type", None)
if model_type in COMPOSITE_MODELS:
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
......@@ -231,6 +198,17 @@ def patch_target_modules(
return target_modules
_register_composite_model(
model_type="gemma3",
)
_register_composite_model(
model_type="llama4",
vision_model_keys=["vision_model"],
)
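# A hedged reading of the llama4 registration above, expanded with the defaults documented in
# `_register_composite_model`; the literal default values are inferred, not shown in this diff:
#   COMPOSITE_MODELS["llama4"] -> CompositeModel(
#       model_type="llama4",
#       projector_key="multi_modal_projector",    # default
#       vision_model_keys=["vision_model"],       # llama4 overrides the default "vision_tower"
#       language_model_keys=["language_model"],   # default
#       lora_conflict_keys=[],                    # default
#   )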
_register_composite_model(
model_type="llava",
)
......@@ -285,6 +263,15 @@ _register_composite_model(
)
_register_composite_model(
model_type="qwen2_5_omni_thinker",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
_register_composite_model(
model_type="qwen2_vl",
projector_key="visual.merger",
......
......@@ -13,7 +13,7 @@
# limitations under the License.
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any
import torch
from peft import PeftModel
......@@ -27,19 +27,14 @@ from ..extras.packages import is_transformers_version_greater_than
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer
from .model_utils.kv_cache import configure_kv_cache
from .model_utils.longlora import configure_longlora
from .model_utils.moe import add_z3_leaf_module, configure_moe
from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import (
autocast_projector_dtype,
configure_visual_model,
get_image_seqlen,
get_patch_size,
get_vision_feature_select_strategy,
)
from .model_utils.visual import autocast_projector_dtype, configure_visual_model
if TYPE_CHECKING:
......@@ -56,8 +51,8 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
tokenizer.model_max_length = model_args.model_max_length
if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
tokenizer.model_max_length = model_args.model_max_length # enlarge the tokenizer max length
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
......@@ -72,28 +67,25 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
def patch_processor(
processor: "ProcessorMixin",
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
) -> None:
setattr(processor, "tokenizer", tokenizer)
if getattr(config, "vision_config", None) is not None: # visual models
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "patch_size", get_patch_size(config, processor))
setattr(processor, "image_max_pixels", model_args.image_max_pixels)
setattr(processor, "image_min_pixels", model_args.image_min_pixels)
setattr(processor, "video_max_pixels", model_args.video_max_pixels)
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
setattr(processor, "image_max_pixels", model_args.image_max_pixels)
setattr(processor, "image_min_pixels", model_args.image_min_pixels)
setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
setattr(processor, "video_max_pixels", model_args.video_max_pixels)
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
init_kwargs: dict[str, Any],
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
......@@ -112,19 +104,13 @@ def patch_config(
configure_moe(config, model_args, is_trainable)
configure_visual_model(config)
configure_packing(model_args, is_trainable)
if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
logger.info_rank0("Using KV cache for faster generation.")
configure_kv_cache(config, model_args, is_trainable)
if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype)
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
if getattr(config, "model_type", None) == "minicpmo":
setattr(config, "init_audio", True)
setattr(config, "init_tts", False)
......@@ -138,15 +124,13 @@ def patch_config(
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
# cast data type of the model if:
# 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
# 2. quantization_bit is not None (qlora)
if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
# do not cast the data type of the model under deepspeed zero3 without qlora
if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
init_kwargs["torch_dtype"] = model_args.compute_dtype
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True
if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map
init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder
......
......@@ -19,7 +19,7 @@ import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
import transformers
......@@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
r"""
r"""Fix the valuehead checkpoint files.
The model is already unwrapped.
There are three cases:
......@@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
if safe_serialization:
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
else:
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
os.remove(path_to_checkpoint)
decoder_state_dict, v_head_state_dict = {}, {}
......@@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
class FixValueHeadModelCallback(TrainerCallback):
r"""
A callback for fixing the checkpoint for valuehead models.
"""
r"""A callback for fixing the checkpoint for valuehead models."""
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
......@@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
class SaveProcessorCallback(TrainerCallback):
r"""
A callback for saving the processor.
"""
r"""A callback for saving the processor."""
def __init__(self, processor: "ProcessorMixin") -> None:
self.processor = processor
......@@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
class PissaConvertCallback(TrainerCallback):
r"""
A callback for converting the PiSSA adapter to a normal one.
"""
r"""A callback for converting the PiSSA adapter to a normal one."""
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
......@@ -166,20 +161,17 @@ class PissaConvertCallback(TrainerCallback):
model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
model.save_pretrained(
pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
) # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0)
pissa_convert_dir,
safe_serialization=args.save_safetensors,
path_initial_model_for_weight_conversion=pissa_init_dir,
)
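# switches to `path_initial_model_for_weight_conversion` (peft>=0.12.0, see the TODO removed above)
# in place of the older `convert_pissa_to_lora` argument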
model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
model.set_adapter("default")
if "pissa_init" in model.peft_config.keys(): # backward compatibility (peft<0.12.0)
model.delete_adapter("pissa_init")
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
class LogCallback(TrainerCallback):
r"""
A callback for logging training and evaluation status.
"""
r"""A callback for logging training and evaluation status."""
def __init__(self) -> None:
# Progress
......@@ -188,7 +180,7 @@ class LogCallback(TrainerCallback):
self.max_steps = 0
self.elapsed_time = ""
self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None
self.thread_pool: Optional[ThreadPoolExecutor] = None
# Status
self.aborted = False
self.do_train = False
......@@ -219,7 +211,7 @@ class LogCallback(TrainerCallback):
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n")
......@@ -348,9 +340,7 @@ class LogCallback(TrainerCallback):
class ReporterCallback(TrainerCallback):
r"""
A callback for reporting training status to external logger.
"""
r"""A callback for reporting training status to external logger."""
def __init__(
self,
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
......@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
import torch.nn.functional as F
......@@ -128,16 +128,12 @@ class CustomDPOTrainer(DPOTrainer):
return super()._get_train_sampler()
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
def get_batch_samples(self, *args, **kwargs):
r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
return Trainer.get_batch_samples(self, *args, **kwargs)
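# (assumed rationale) forwarding *args/**kwargs keeps the override working even if the signature
# of `Trainer.get_batch_samples` changes across transformers versions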
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
"""
r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
log_odds = (chosen_logps - rejected_logps) - (
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
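# log-odds ratio from ORPO: log[odds(p_chosen) / odds(p_rejected)] with odds(p) = p / (1 - p);
# log1p(-exp(logp)) computes log(1 - p) from a log-probability in a numerically safer way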
......@@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
return orpo_loss
def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes SimPO loss for batched log probabilities of the policy model.
"""
r"""Compute SimPO loss for batched log probabilities of the policy model."""
pi_logratios = chosen_logps - rejected_logps
gamma_logratios = self.simpo_gamma / self.beta
logits = pi_logratios - gamma_logratios
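# (hedged note) the target margin gamma is divided by beta here, so the sigmoid loss applied to
# beta * logits in the elided code effectively enforces a reward margin of simpo_gamma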
......@@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
policy_rejected_logps: "torch.Tensor",
reference_chosen_logps: Optional["torch.Tensor"],
reference_rejected_logps: Optional["torch.Tensor"],
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes loss for preference learning.
"""
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute loss for preference learning."""
if not self.finetuning_args.use_ref_model:
if self.loss_type == "orpo":
losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
......@@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length
......@@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):
@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""
Computes log probabilities of the reference model.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r"""Compute log probabilities of the reference model."""
if not self.finetuning_args.use_ref_model:
return None, None
......@@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
batch: dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train",
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
......@@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Subclass and override to accept extra kwargs.
"""
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Subclass and override to accept extra kwargs."""
return super().compute_loss(model, inputs, return_outputs)
@override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r"""Log `logs` on the various objects watching training, including stored metrics."""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
......@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
......@@ -38,7 +38,7 @@ def run_dpo(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
......@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
from transformers import Trainer
......@@ -120,28 +120,22 @@ class CustomKTOTrainer(KTOTrainer):
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
r"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return Trainer._get_train_sampler(self)
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
def get_batch_samples(self, *args, **kwargs):
r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
return Trainer.get_batch_samples(self, *args, **kwargs)
@override
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Runs forward pass and computes the log probabilities.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Run forward pass and computes the log probabilities."""
batch = nested_detach(batch, clone=True) # avoid error
model_inputs = {
"input_ids": batch[f"{prefix}input_ids"],
......@@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logits, target_logps, target_logps_avg = self.forward(model, batch)
with torch.no_grad():
_, kl_logps, _ = self.forward(model, batch, prefix="kl_")
......@@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):
@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes log probabilities of the reference model.
"""
self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""Compute log probabilities of the reference model."""
if self.ref_model is None:
ref_model = model
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
......@@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
batch: dict[str, "torch.Tensor"],
) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
(
policy_chosen_logps,
......@@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Subclass and override to accept extra kwargs.
"""
self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r"""Subclass and override to accept extra kwargs."""
return super().compute_loss(model, inputs, return_outputs)
@override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r"""Log `logs` on the various objects watching training, including stored metrics."""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
prefix = "eval_" if train_eval == "eval" else ""
......@@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
if f"count/{split}" in metric_dict:
for key in ("rewards", "logps", "logits"):
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
......@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
......@@ -37,7 +37,7 @@ def run_kto(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
callbacks: Optional[list["TrainerCallback"]] = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
......
......@@ -14,7 +14,7 @@
import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
from typing import TYPE_CHECKING, Literal, Optional
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
......@@ -31,10 +31,8 @@ if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
r"""
Gets reward scores from the API server.
"""
def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
r"""Get reward scores from the API server."""
headers = {"Content-Type": "application/json"}
payload = {"model": "model", "messages": messages}
response = requests.post(server_url, json=payload, headers=headers)
......@@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
r"""
Replaces the default/reward modules in the model. The model is already unwrapped.
"""
r"""Replace the default/reward modules in the model. The model is already unwrapped."""
v_head_layer = model.v_head.summary
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
......@@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
r"""
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
layer_norm_params = {}
for name, param in model.named_parameters():
if param.data.dtype == torch.float32:
......@@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
return layer_norm_params
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
r"""
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
for name, param in model.named_parameters():
if name in layernorm_params:
param.data = layernorm_params[name]