Commit 7ea81099 authored by chenych

update llama4

parent 84987715
-# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and PEFT library,
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
@@ -21,7 +21,7 @@
 import inspect
 from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 import torch
@@ -40,9 +40,7 @@ logger = logging.get_logger(__name__)
 def get_unsloth_gradient_checkpointing_func() -> Callable:
     class UnslothGradientCheckpointing(torch.autograd.Function):
-        r"""
-        Saves VRAM by smartly offloading to RAM.
-        """
+        r"""Saves VRAM by smartly offloading to RAM."""
         @staticmethod
         @torch.cuda.amp.custom_fwd
@@ -77,13 +75,14 @@ def get_unsloth_gradient_checkpointing_func() -> Callable:
 def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
-    r"""
-    Only applies gradient checkpointing to trainable layers.
-    """
+    r"""Only applies gradient checkpointing to trainable layers."""
     @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: "torch.nn.Module" = func.__self__
+        if isinstance(func, partial):
+            module: torch.nn.Module = func.func.__self__
+        else:
+            module: torch.nn.Module = func.__self__
         has_grad = False
         if any(param.requires_grad for param in module.parameters()):
@@ -103,11 +102,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 def _gradient_checkpointing_enable(
     self: "PreTrainedModel",
-    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None,
+    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
     use_unsloth_gc: bool = False,
 ) -> None:
-    r"""
-    Activates gradient checkpointing for the current model.
+    r"""Activates gradient checkpointing for the current model.
 
     Modification of the original method to enable gradient checkpointing for block-wise optimizer.
     """
@@ -134,17 +132,18 @@ def _gradient_checkpointing_enable(
 def _fp32_forward_post_hook(
-    module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+    module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
 ) -> "torch.Tensor":
     return output.to(torch.float32)
 def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
-    r"""
-    Includes:
-    (1) cast the layernorm in fp32
-    (2) make output embedding layer require grads
-    (3) add the upcasting of the lm_head in fp32
-    """
+    r"""Prepare the model before training.
+
+    Include:
+    (1) cast the layernorm in fp32
+    (2) make output embedding layer require grads
+    (3) add the upcasting of the lm_head in fp32.
+    """
     if model_args.upcast_layernorm:
         logger.info_rank0("Upcasting layernorm weights in float32.")
...
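The new `isinstance(func, partial)` branch accounts for recent Transformers releases, which may pass the checkpointed layer's bound method wrapped in `functools.partial`; in that case the owning module is only reachable through `func.func.__self__`. A minimal sketch of that lookup, assuming nothing beyond the standard library and PyTorch (the `resolve_module` helper below is illustrative, not part of the patch):

```python
from functools import partial

import torch


def resolve_module(func) -> torch.nn.Module:
    """Return the module that owns a (possibly partial-wrapped) bound method."""
    if isinstance(func, partial):  # e.g. partial(layer.__call__, **extra_kwargs)
        return func.func.__self__
    return func.__self__


layer = torch.nn.Linear(4, 4)
assert resolve_module(layer.forward) is layer
assert resolve_module(partial(layer.forward)) is layer
```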
@@ -38,9 +38,7 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
 def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
-    r"""
-    Resize token embeddings.
-    """
+    r"""Resize token embeddings."""
     if is_deepspeed_zero3_enabled():
         import deepspeed  # type: ignore
...
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...extras import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+    from ...hparams import ModelArguments
+
+
+def configure_kv_cache(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    if not is_trainable:
+        setattr(config, "use_cache", model_args.use_cache)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", model_args.use_cache)
+
+        if model_args.use_cache:
+            logger.info_rank0("KV cache is enabled for faster generation.")
+        else:
+            logger.info_rank0("KV cache is disabled.")
+    else:
+        setattr(config, "use_cache", False)
+        if hasattr(config, "text_config"):
+            setattr(config.text_config, "use_cache", False)
+
+        logger.info_rank0("KV cache is disabled during training.")
@@ -27,39 +27,6 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
-def apply_liger_kernel_to_qwen2_5_vl(
-    rope: bool = True,
-    cross_entropy: bool = False,
-    fused_linear_cross_entropy: bool = True,
-    rms_norm: bool = True,
-    swiglu: bool = True,
-) -> None:
-    from liger_kernel.transformers import LigerCrossEntropyLoss, LigerRMSNorm, LigerSwiGLUMLP
-    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
-    from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
-    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
-
-    def get_dtype(self: "modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel"):
-        return self.dtype
-
-    modeling_qwen2_5_vl.Qwen2_5_VisionTransformerPretrainedModel.get_dtype = get_dtype
-
-    if rope:
-        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
-    if rms_norm:
-        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
-    if cross_entropy:
-        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
-    if fused_linear_cross_entropy:
-        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_lce_forward
-    if swiglu:
-        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
 def apply_liger_kernel(
     config: "PretrainedConfig",
     model_args: "ModelArguments",
@@ -74,6 +41,12 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
     elif model_type == "gemma2":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
+    elif model_type == "gemma3":
+        from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
+    elif model_type == "gemma3_text":
+        from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
if model_type == "paligemma":
+        from liger_kernel.transformers import apply_liger_kernel_to_paligemma as apply_liger_kernel
     elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
     elif model_type == "mistral":
@@ -89,7 +62,7 @@ def apply_liger_kernel(
     elif model_type == "qwen2_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     elif model_type == "qwen2_5_vl":
-        apply_liger_kernel = apply_liger_kernel_to_qwen2_5_vl
+        from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl as apply_liger_kernel
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return
...
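The local monkey-patch for Qwen2.5-VL is removed because the `liger-kernel` package now ships its own `apply_liger_kernel_to_qwen2_5_vl`, so the dispatcher can treat it like every other architecture: import the per-model patcher under a common alias and call it. A simplified mirror of that dispatch (the `resolve_liger_patch` name and the reduced mapping are illustrative; the real function covers more model types and forwards extra kwargs):

```python
def resolve_liger_patch(model_type: str):
    """Return the liger-kernel patch function for a model type, or None if unsupported."""
    try:
        import liger_kernel.transformers as liger_transformers
    except ImportError:
        return None

    patchers = {
        "gemma": "apply_liger_kernel_to_gemma",
        "gemma2": "apply_liger_kernel_to_gemma2",
        "gemma3": "apply_liger_kernel_to_gemma3",
        "llama": "apply_liger_kernel_to_llama",
        "qwen2_vl": "apply_liger_kernel_to_qwen2_vl",
        "qwen2_5_vl": "apply_liger_kernel_to_qwen2_5_vl",
    }
    name = patchers.get(model_type)
    return getattr(liger_transformers, name, None) if name else None
```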
-# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
+# Copyright 2025 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
 #
 # This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
@@ -18,7 +18,7 @@
 # limitations under the License.
 import math
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional
 import torch
 import torch.nn as nn
@@ -54,14 +54,14 @@ def llama_attention_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     bsz, q_len, _ = hidden_states.size()
-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     # LlamaFlashAttention2 attention does not support output_attentions
     output_attentions = False
     bsz, q_len, _ = hidden_states.size()
-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
     if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward
-        attn_output: "torch.Tensor" = _flash_attention_forward(
+        attn_output: torch.Tensor = _flash_attention_forward(
             query_states,
             key_states,
             value_states,
@@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
             is_causal=self.is_causal,
         )
     else:
-        attn_output: "torch.Tensor" = self._flash_attention_forward(
+        attn_output: torch.Tensor = self._flash_attention_forward(
             query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
         )
@@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     if output_attentions:
         transformers_logger.warning_once(
             "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(
     bsz, q_len, _ = hidden_states.size()
-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING
 from ...extras import logging
 from .visual import COMPOSITE_MODELS
@@ -25,10 +25,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
-def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
-    r"""
-    Finds all available modules to apply LoRA, GaLore or APOLLO.
-    """
+def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
+    r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
     model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
     if model_type == "chatglm":
@@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
     return list(module_names)
-def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
-    r"""
-    Finds the modules in the expanded blocks to apply lora.
-    """
+def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
+    r"""Find the modules in the expanded blocks to apply lora."""
     num_layers = getattr(model.config, "num_hidden_layers", None)
     if not num_layers:
         raise ValueError("Model was not supported.")
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments
-def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None:
     check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules  # type: ignore
@@ -34,9 +34,7 @@ def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch
 def add_z3_leaf_module(model: "PreTrainedModel") -> None:
-    r"""
-    Sets module as a leaf module to skip partitioning in deepspeed zero3.
-    """
+    r"""Set module as a leaf module to skip partitioning in deepspeed zero3."""
     if not is_deepspeed_zero3_enabled():
         return
...
-# Copyright 2024 Musab Gultekin and the LlamaFactory team.
+# Copyright 2025 Musab Gultekin and the LlamaFactory team.
 #
 # This code is based on the Musab Gultekin's functionary library.
 # https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
@@ -37,7 +37,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING
 import torch
 import torch.nn.functional as F
@@ -59,8 +59,7 @@ logger = logging.get_logger(__name__)
 def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r""" r"""Get the sequnce lengths in the current batch.
Gets the sequnce lengths in the current batch.
e.g. e.g.
```python ```python
@@ -76,7 +75,7 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     bsz = attention_mask.size(0)
     dtype, device = attention_mask.dtype, attention_mask.device
     max_num = torch.max(attention_mask).item()
-    counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
+    counts: torch.Tensor = torch.zeros((bsz, max_num), dtype=dtype, device=device)
     for i in range(max_num):
         counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
@@ -85,9 +84,8 @@ def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
     return seqlens
-def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
-    r"""
-    Prepares the indices and seqlens for flash attn varlen function.
+def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "torch.Tensor", int]:
+    r"""Prepare the indices and seqlens for flash attn varlen function.
 
     Returns:
         indices: indices of non-masked tokens from the flattened sequence.
@@ -106,6 +104,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
     [0, 2, 5, 6, 8, 11]
     3
     ```
     """
     seqlens_in_batch = get_seqlens_in_batch(attention_mask)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
...
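In the packed-sequence attention mask used here, the value k tags tokens belonging to the k-th sequence packed into a row and 0 marks padding, so counting the occurrences of each value recovers the individual sequence lengths. A self-contained mirror of that counting (the input mask below is illustrative, not quoted from the docstring excerpt):

```python
import torch


def seqlens_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
    # value k in the mask marks tokens of the k-th packed sequence; 0 is padding
    bsz = attention_mask.size(0)
    max_num = int(attention_mask.max().item())
    counts = torch.zeros((bsz, max_num), dtype=attention_mask.dtype)
    for i in range(max_num):
        counts[:, i] = (attention_mask == (i + 1)).sum(dim=-1)

    counts = counts.flatten()
    return counts[counts.nonzero().squeeze(dim=-1)]


# row 0 packs sequences of length 2 and 3 (one pad token), row 1 packs 1, 2 and 3
mask = torch.tensor([[1, 1, 2, 2, 2, 0], [1, 2, 2, 3, 3, 3]])
print(seqlens_in_batch(mask))  # tensor([2, 3, 1, 2, 3])
```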
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and Optimum library.
 # https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
@@ -19,7 +19,7 @@
 import os
 import random
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any
 import torch
 from datasets import load_dataset
@@ -43,9 +43,7 @@ logger = logging.get_logger(__name__)
 @unique
 class QuantizationMethod(str, Enum):
-    r"""
-    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
-    """
+    r"""Borrowed from `transformers.utils.quantization_config.QuantizationMethod`."""
     BITS_AND_BYTES = "bitsandbytes"
     GPTQ = "gptq"
@@ -56,10 +54,8 @@ class QuantizationMethod(str, Enum):
     HQQ = "hqq"
-def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
-    r"""
-    Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
-    """
+def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> list[dict[str, Any]]:
+    r"""Prepare the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization."""
     if os.path.isfile(model_args.export_quantization_dataset):
         data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
         data_files = model_args.export_quantization_dataset
@@ -84,7 +80,7 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
             raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
         sample_idx = random.randint(0, len(dataset) - 1)
-        sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
+        sample: dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
         n_try += 1
         if sample["input_ids"].size(1) > maxlen:
             break  # TODO: fix large maxlen
@@ -101,11 +97,9 @@ def configure_quantization(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    init_kwargs: Dict[str, Any],
+    init_kwargs: dict[str, Any],
 ) -> None:
-    r"""
-    Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
-    """
+    r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
     if getattr(config, "quantization_config", None):  # ptq
         if model_args.quantization_bit is not None:
             logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
@@ -113,7 +107,7 @@ def configure_quantization(
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
-        quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+        quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
         if quant_method == QuantizationMethod.GPTQ:
...
-# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2025 LMSYS and the LlamaFactory team.
 # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
 #
 # This code is inspired by the LMSYS's FastChat library.
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Optional
 from ...extras import logging
 from ...extras.misc import get_current_device
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
 def _get_unsloth_kwargs(
     config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
         "max_seq_length": model_args.model_max_length or 4096,
@@ -47,10 +47,8 @@ def _get_unsloth_kwargs(
 def load_unsloth_pretrained_model(
     config: "PretrainedConfig", model_args: "ModelArguments"
 ) -> Optional["PreTrainedModel"]:
-    r"""
-    Optionally loads pretrained model with unsloth. Used in training.
-    """
-    from unsloth import FastLanguageModel
+    r"""Optionally load pretrained model with unsloth. Used in training."""
+    from unsloth import FastLanguageModel  # type: ignore
     unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
     try:
@@ -64,12 +62,10 @@ def load_unsloth_pretrained_model(
 def get_unsloth_peft_model(
-    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
+    model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: dict[str, Any]
 ) -> "PreTrainedModel":
-    r"""
-    Gets the peft model for the pretrained model with unsloth. Used in training.
-    """
-    from unsloth import FastLanguageModel
+    r"""Get the peft model for the pretrained model with unsloth. Used in training."""
+    from unsloth import FastLanguageModel  # type: ignore
     unsloth_peft_kwargs = {
         "model": model,
@@ -82,10 +78,8 @@ def get_unsloth_peft_model(
 def load_unsloth_peft_model(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> "PreTrainedModel":
-    r"""
-    Loads peft model with unsloth. Used in both training and inference.
-    """
-    from unsloth import FastLanguageModel
+    r"""Load peft model with unsloth. Used in both training and inference."""
+    from unsloth import FastLanguageModel  # type: ignore
     unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
     try:
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 import torch
 from transformers.utils import cached_file
@@ -30,9 +30,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
-def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
-    r"""
-    Loads value head parameters from Hugging Face Hub or local disk.
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> dict[str, torch.Tensor]:
+    r"""Load value head parameters from Hugging Face Hub or local disk.
 
     Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
     """
...
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
+from typing import TYPE_CHECKING, Optional
 import torch
 import transformers
@@ -27,7 +27,7 @@ from ...extras import logging
 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
     from ...hparams import FinetuningArguments, ModelArguments
@@ -40,9 +40,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
 class CompositeModel:
     model_type: str
     projector_key: str
-    vision_model_keys: List[str]
-    language_model_keys: List[str]
-    lora_conflict_keys: List[str]
+    vision_model_keys: list[str]
+    language_model_keys: list[str]
+    lora_conflict_keys: list[str]
     def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
         for key in self.projector_key.split("."):
@@ -51,16 +51,26 @@ class CompositeModel:
         return module
-COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
+COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
 def _register_composite_model(
     model_type: str,
     projector_key: Optional[str] = None,
-    vision_model_keys: Optional[List[str]] = None,
-    language_model_keys: Optional[List[str]] = None,
-    lora_conflict_keys: Optional[List[str]] = None,
+    vision_model_keys: Optional[list[str]] = None,
+    language_model_keys: Optional[list[str]] = None,
+    lora_conflict_keys: Optional[list[str]] = None,
 ):
+    r"""Register a new composite model.
+
+    Args:
+        model_type: model type
+        projector_key: multi_modal_projector
+        vision_model_keys: vision_tower
+        language_model_keys: language_model
+        lora_conflict_keys: None
+    """
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
         projector_key=projector_key or "multi_modal_projector",
@@ -116,12 +126,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
 def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
-    r"""
-    Casts projector output to half precision for fine-tuning quantized VLMs.
-    """
+    r"""Cast projector output to half precision for fine-tuning quantized VLMs."""
     def _mm_projector_forward_post_hook(
-        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+        module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
@@ -137,9 +145,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
 def configure_visual_model(config: "PretrainedConfig") -> None:
-    r"""
-    Patches VLMs before loading them.
-    """
+    r"""Patch VLMs before loading them."""
     if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
         # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
@@ -149,10 +155,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
-def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
-    r"""
-    Freezes vision tower and language model for VLM full/freeze tuning.
-    """
+def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
+    r"""Freeze vision tower and language model for VLM full/freeze tuning."""
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
     if model_type in COMPOSITE_MODELS:
@@ -174,47 +178,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     return forbidden_modules
-def get_image_seqlen(config: "PretrainedConfig") -> int:
-    r"""
-    Computes the number of special tokens per image.
-    """
-    model_type = getattr(config, "model_type", None)
-    if model_type == "llava":
-        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
-        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
-            image_seqlen += 1
-    elif model_type == "paligemma":
-        image_seqlen = config.vision_config.num_image_tokens
-    else:
-        image_seqlen = -1
-
-    return image_seqlen
-
-
-def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""
-    Computes the patch size of the vit.
-    """
-    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
-    return patch_size
-
-
-def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""
-    Get the vision_feature_select_strategy.
-    """
-    vision_feature_select_strategy = getattr(
-        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
-    )
-    return vision_feature_select_strategy
 def patch_target_modules(
-    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
-) -> List[str]:
-    r"""
-    Freezes vision tower for VLM LoRA tuning.
-    """
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: list[str]
+) -> list[str]:
+    r"""Freeze vision tower for VLM LoRA tuning."""
     model_type = getattr(model.config, "model_type", None)
     if model_type in COMPOSITE_MODELS:
         forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
@@ -231,6 +198,17 @@ def patch_target_modules(
     return target_modules
+_register_composite_model(
+    model_type="gemma3",
+)
+
+
+_register_composite_model(
+    model_type="llama4",
+    vision_model_keys=["vision_model"],
+)
 _register_composite_model(
     model_type="llava",
 )
@@ -285,6 +263,15 @@
 )
+_register_composite_model(
+    model_type="qwen2_5_omni_thinker",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
+    language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
 _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
...
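This is the core of the "update llama4" change: registering `llama4` (and `gemma3`) in `COMPOSITE_MODELS` tells the freeze/LoRA logic which submodules form the vision tower, the projector, and the language model, using the defaults documented in `_register_composite_model` (projector `multi_modal_projector`, vision `vision_tower`, language `language_model`); Llama 4 only overrides the vision key. A self-contained sketch of how such a registry can drive vision-tower freezing (`CompositeEntry`, `REGISTRY`, and `vision_modules_to_freeze` are illustrative stand-ins, not the project's classes):

```python
from dataclasses import dataclass, field


@dataclass
class CompositeEntry:
    # pared-down stand-in for the CompositeModel dataclass in the diff
    projector_key: str = "multi_modal_projector"
    vision_model_keys: list = field(default_factory=lambda: ["vision_tower"])
    language_model_keys: list = field(default_factory=lambda: ["language_model"])


REGISTRY = {
    "gemma3": CompositeEntry(),  # all defaults
    "llama4": CompositeEntry(vision_model_keys=["vision_model"]),  # mirrors the new registration
}


def vision_modules_to_freeze(model_type: str, freeze_vision_tower: bool = True) -> set:
    entry = REGISTRY.get(model_type)
    return set(entry.vision_model_keys) if (entry and freeze_vision_tower) else set()


print(vision_modules_to_freeze("llama4"))  # {'vision_model'}
```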
@@ -13,7 +13,7 @@
 # limitations under the License.
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any
 import torch
 from peft import PeftModel
@@ -27,19 +27,14 @@ from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
 from .model_utils.embedding import resize_embedding_layer
+from .model_utils.kv_cache import configure_kv_cache
 from .model_utils.longlora import configure_longlora
 from .model_utils.moe import add_z3_leaf_module, configure_moe
 from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
-from .model_utils.visual import (
-    autocast_projector_dtype,
-    configure_visual_model,
-    get_image_seqlen,
-    get_patch_size,
-    get_vision_feature_select_strategy,
-)
+from .model_utils.visual import autocast_projector_dtype, configure_visual_model
 if TYPE_CHECKING:
@@ -56,8 +51,8 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
-    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
-        tokenizer.model_max_length = model_args.model_max_length
+    if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length  # enlarge the tokenizer max length
     if model_args.new_special_tokens is not None:
         num_added_tokens = tokenizer.add_special_tokens(
@@ -72,28 +67,25 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
 def patch_processor(
     processor: "ProcessorMixin",
-    config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
 ) -> None:
     setattr(processor, "tokenizer", tokenizer)
-    if getattr(config, "vision_config", None) is not None:  # visual models
-        setattr(processor, "image_seqlen", get_image_seqlen(config))
-        setattr(processor, "patch_size", get_patch_size(config, processor))
-        setattr(processor, "image_max_pixels", model_args.image_max_pixels)
-        setattr(processor, "image_min_pixels", model_args.image_min_pixels)
-        setattr(processor, "video_max_pixels", model_args.video_max_pixels)
-        setattr(processor, "video_min_pixels", model_args.video_min_pixels)
-        setattr(processor, "video_fps", model_args.video_fps)
-        setattr(processor, "video_maxlen", model_args.video_maxlen)
-        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
+    setattr(processor, "image_max_pixels", model_args.image_max_pixels)
+    setattr(processor, "image_min_pixels", model_args.image_min_pixels)
+    setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
+    setattr(processor, "video_max_pixels", model_args.video_max_pixels)
+    setattr(processor, "video_min_pixels", model_args.video_min_pixels)
+    setattr(processor, "video_fps", model_args.video_fps)
+    setattr(processor, "video_maxlen", model_args.video_maxlen)
+    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
 def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
-    init_kwargs: Dict[str, Any],
+    init_kwargs: dict[str, Any],
     is_trainable: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
@@ -112,19 +104,13 @@ def patch_config(
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)
+    configure_kv_cache(config, model_args, is_trainable)
-    if model_args.use_cache and not is_trainable:
-        setattr(config, "use_cache", True)
-        logger.info_rank0("Using KV cache for faster generation.")
     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
             setattr(config, dtype_name, model_args.compute_dtype == dtype)
-    if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
-        setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
     if getattr(config, "model_type", None) == "minicpmo":
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)
@@ -138,15 +124,13 @@ def patch_config(
     # deepspeed zero3 is not compatible with low_cpu_mem_usage
     init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
-    # cast data type of the model if:
-    # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
-    # 2. quantization_bit is not None (qlora)
-    if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
+    # do not cast data type of the model if using deepspeed zero3 without qlora
+    if not (is_deepspeed_zero3_enabled() and model_args.quantization_bit is None):
         init_kwargs["torch_dtype"] = model_args.compute_dtype
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True if init_kwargs["low_cpu_mem_usage"] and not is_fsdp_enabled(): # fsdp does not need device map
if "device_map" not in init_kwargs and model_args.device_map: if "device_map" not in init_kwargs and model_args.device_map:
init_kwargs["device_map"] = model_args.device_map init_kwargs["device_map"] = model_args.device_map # device map requires low_cpu_mem_usage=True
if init_kwargs.get("device_map", None) == "auto": if init_kwargs.get("device_map", None) == "auto":
init_kwargs["offload_folder"] = model_args.offload_folder init_kwargs["offload_folder"] = model_args.offload_folder
......
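The dtype-casting rule above was simplified: the checkpoint is now loaded in `compute_dtype` unless DeepSpeed ZeRO-3 is active without QLoRA (the old rule also kept plain FSDP in float32). A throwaway truth-table check of the new predicate (`should_cast` is illustrative, not project code):

```python
def should_cast(zero3_enabled: bool, quantization_bit) -> bool:
    # new rule from the diff: skip casting only for ZeRO-3 without QLoRA
    return not (zero3_enabled and quantization_bit is None)


for zero3 in (False, True):
    for qbit in (None, 4):
        print(f"zero3={zero3}, quantization_bit={qbit} -> cast={should_cast(zero3, qbit)}")
# only zero3=True with quantization_bit=None keeps the model in float32
```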
...@@ -19,7 +19,7 @@ import sys ...@@ -19,7 +19,7 @@ import sys
import time import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
import transformers import transformers
...@@ -56,7 +56,8 @@ logger = logging.get_logger(__name__) ...@@ -56,7 +56,8 @@ logger = logging.get_logger(__name__)
def fix_valuehead_checkpoint( def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None: ) -> None:
r""" r"""Fix the valuehead checkpoint files.
The model is already unwrapped. The model is already unwrapped.
There are three cases: There are three cases:
...@@ -72,10 +73,10 @@ def fix_valuehead_checkpoint( ...@@ -72,10 +73,10 @@ def fix_valuehead_checkpoint(
if safe_serialization: if safe_serialization:
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} state_dict: dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
else: else:
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") state_dict: dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
os.remove(path_to_checkpoint) os.remove(path_to_checkpoint)
decoder_state_dict, v_head_state_dict = {}, {} decoder_state_dict, v_head_state_dict = {}, {}
...@@ -98,9 +99,7 @@ def fix_valuehead_checkpoint( ...@@ -98,9 +99,7 @@ def fix_valuehead_checkpoint(
class FixValueHeadModelCallback(TrainerCallback): class FixValueHeadModelCallback(TrainerCallback):
r""" r"""A callback for fixing the checkpoint for valuehead models."""
A callback for fixing the checkpoint for valuehead models.
"""
@override @override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
...@@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback): ...@@ -112,9 +111,7 @@ class FixValueHeadModelCallback(TrainerCallback):
class SaveProcessorCallback(TrainerCallback): class SaveProcessorCallback(TrainerCallback):
r""" r"""A callback for saving the processor."""
A callback for saving the processor.
"""
def __init__(self, processor: "ProcessorMixin") -> None: def __init__(self, processor: "ProcessorMixin") -> None:
self.processor = processor self.processor = processor
...@@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback): ...@@ -132,9 +129,7 @@ class SaveProcessorCallback(TrainerCallback):
class PissaConvertCallback(TrainerCallback): class PissaConvertCallback(TrainerCallback):
r""" r"""A callback for converting the PiSSA adapter to a normal one."""
A callback for converting the PiSSA adapter to a normal one.
"""
@override @override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
...@@ -166,20 +161,17 @@ class PissaConvertCallback(TrainerCallback): ...@@ -166,20 +161,17 @@ class PissaConvertCallback(TrainerCallback):
model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors) model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
model.save_pretrained( model.save_pretrained(
pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir pissa_convert_dir,
) # TODO: use `path_initial_model_for_weight_conversion` (peft>=0.12.0) safe_serialization=args.save_safetensors,
path_initial_model_for_weight_conversion=pissa_init_dir,
)
model.load_adapter(pissa_backup_dir, "default", is_trainable=True) model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
model.set_adapter("default") model.set_adapter("default")
if "pissa_init" in model.peft_config.keys(): # backward compatibility (peft<0.12.0)
model.delete_adapter("pissa_init")
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
class LogCallback(TrainerCallback): class LogCallback(TrainerCallback):
r""" r"""A callback for logging training and evaluation status."""
A callback for logging training and evaluation status.
"""
def __init__(self) -> None: def __init__(self) -> None:
# Progress # Progress
...@@ -188,7 +180,7 @@ class LogCallback(TrainerCallback): ...@@ -188,7 +180,7 @@ class LogCallback(TrainerCallback):
self.max_steps = 0 self.max_steps = 0
self.elapsed_time = "" self.elapsed_time = ""
self.remaining_time = "" self.remaining_time = ""
self.thread_pool: Optional["ThreadPoolExecutor"] = None self.thread_pool: Optional[ThreadPoolExecutor] = None
# Status # Status
self.aborted = False self.aborted = False
self.do_train = False self.do_train = False
...@@ -219,7 +211,7 @@ class LogCallback(TrainerCallback): ...@@ -219,7 +211,7 @@ class LogCallback(TrainerCallback):
self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
self.remaining_time = str(timedelta(seconds=int(remaining_time))) self.remaining_time = str(timedelta(seconds=int(remaining_time)))
def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None: def _write_log(self, output_dir: str, logs: dict[str, Any]) -> None:
with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f: with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
f.write(json.dumps(logs) + "\n") f.write(json.dumps(logs) + "\n")
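`_write_log` appends one JSON object per line to the `TRAINER_LOG` file, so the log can be tailed or parsed incrementally. A small reader sketch; the default file name and the field names mentioned in the comment are assumptions, not guaranteed keys:

    import json
    import os

    def read_trainer_log(output_dir: str, filename: str = "trainer_log.jsonl") -> list[dict]:
        """Parse the JSON-lines training log (the default file name here is an assumption)."""
        entries = []
        with open(os.path.join(output_dir, filename), encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    entries.append(json.loads(line))
        return entries

    # e.g. plot loss vs. steps, assuming entries expose keys such as "current_steps" and "loss"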
...@@ -348,9 +340,7 @@ class LogCallback(TrainerCallback): ...@@ -348,9 +340,7 @@ class LogCallback(TrainerCallback):
class ReporterCallback(TrainerCallback): class ReporterCallback(TrainerCallback):
r""" r"""A callback for reporting training status to external logger."""
A callback for reporting training status to external logger.
"""
def __init__( def __init__(
self, self,
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's TRL library. # This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
...@@ -19,7 +19,7 @@ import warnings ...@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -128,16 +128,12 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -128,16 +128,12 @@ class CustomDPOTrainer(DPOTrainer):
return super()._get_train_sampler() return super()._get_train_sampler()
@override @override
-    def get_batch_samples(self, epoch_iterator, num_batches):
-        r"""
-        Replaces the method of KTO Trainer with the one of the standard Trainer.
-        """
-        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+    def get_batch_samples(self, *args, **kwargs):
+        r"""Replace the method of DPO Trainer with the one of the standard Trainer."""
+        return Trainer.get_batch_samples(self, *args, **kwargs)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r""" r"""Compute ORPO's odds ratio (OR) loss for batched log probabilities of the policy model."""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
"""
log_odds = (chosen_logps - rejected_logps) - ( log_odds = (chosen_logps - rejected_logps) - (
torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps)) torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
) )
...@@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -147,9 +143,7 @@ class CustomDPOTrainer(DPOTrainer):
return orpo_loss return orpo_loss
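The visible part of `odds_ratio_loss` computes log(odds(chosen)) - log(odds(rejected)) with odds(p) = p / (1 - p), using `log1p` for numerical stability. The lines elided by the hunk are not reproduced here; the sketch below shows one common way such a term is reduced to a loss (the log-sigmoid reduction follows the ORPO paper and is an assumption about the omitted code):

    import torch
    import torch.nn.functional as F

    def odds_ratio_term(chosen_logps: torch.Tensor, rejected_logps: torch.Tensor) -> torch.Tensor:
        """Sketch of an ORPO odds-ratio penalty for per-sequence average log-probs (strictly negative values)."""
        log_odds = (chosen_logps - rejected_logps) - (
            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
        )
        # Encourage a larger odds ratio for the chosen response over the rejected one.
        return -F.logsigmoid(log_odds)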
def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": def simpo_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r""" r"""Compute SimPO loss for batched log probabilities of the policy model."""
Computes SimPO loss for batched log probabilities of the policy model.
"""
pi_logratios = chosen_logps - rejected_logps pi_logratios = chosen_logps - rejected_logps
gamma_logratios = self.simpo_gamma / self.beta gamma_logratios = self.simpo_gamma / self.beta
logits = pi_logratios - gamma_logratios logits = pi_logratios - gamma_logratios
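SimPO is reference-free: it contrasts the length-normalized log probabilities of the chosen and rejected responses against a target margin gamma scaled by beta. The final reduction is not shown in this hunk; below is a self-contained sketch under the usual -logsigmoid(beta * logits) formulation (the default beta/gamma values are illustrative only):

    import torch
    import torch.nn.functional as F

    def simpo_loss_sketch(
        chosen_logps: torch.Tensor,    # length-normalized log-probs of chosen responses
        rejected_logps: torch.Tensor,  # length-normalized log-probs of rejected responses
        beta: float = 2.0,
        gamma: float = 1.0,
    ) -> torch.Tensor:
        """Reference-free SimPO loss; hyperparameter defaults are placeholders."""
        pi_logratios = chosen_logps - rejected_logps
        gamma_logratios = gamma / beta
        logits = pi_logratios - gamma_logratios
        return -F.logsigmoid(beta * logits)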
...@@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -162,10 +156,8 @@ class CustomDPOTrainer(DPOTrainer):
policy_rejected_logps: "torch.Tensor", policy_rejected_logps: "torch.Tensor",
reference_chosen_logps: Optional["torch.Tensor"], reference_chosen_logps: Optional["torch.Tensor"],
reference_rejected_logps: Optional["torch.Tensor"], reference_rejected_logps: Optional["torch.Tensor"],
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""Compute loss for preference learning."""
Computes loss for preference learning.
"""
if not self.finetuning_args.use_ref_model: if not self.finetuning_args.use_ref_model:
if self.loss_type == "orpo": if self.loss_type == "orpo":
losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps) losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
...@@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -185,17 +177,16 @@ class CustomDPOTrainer(DPOTrainer):
@override @override
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""Compute the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
Otherwise the average log probabilities. Otherwise the average log probabilities.
""" """
if self.finetuning_args.use_ref_model: if self.finetuning_args.use_ref_model:
batch = nested_detach(batch, clone=True) # avoid error batch = nested_detach(batch, clone=True) # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"]) all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
if self.loss_type in ["ipo", "orpo", "simpo"]: if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length all_logps = all_logps / valid_length
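`get_batch_logps` returns the summed label log-probabilities per sequence plus the count of valid label tokens; dividing by `valid_length` yields the length-normalized average that IPO, ORPO and SimPO expect. A minimal sketch of that helper's usual shape (IGNORE_INDEX = -100 is the conventional mask value; the project's real implementation lives elsewhere and may differ):

    import torch

    IGNORE_INDEX = -100  # assumed mask value for prompt and padding positions

    def batch_logps_sketch(logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (sum of label log-probs, number of valid label tokens) for each sequence."""
        labels = labels[:, 1:].clone()   # predict token t from positions < t
        logits = logits[:, :-1, :]
        mask = labels != IGNORE_INDEX
        labels[~mask] = 0                # any in-vocabulary index; masked out below
        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
        ).squeeze(2)
        return (per_token_logps * mask).sum(-1), mask.sum(-1)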
...@@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -212,11 +203,9 @@ class CustomDPOTrainer(DPOTrainer):
@override @override
def compute_reference_log_probs( def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]: ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
r""" r"""Compute log probabilities of the reference model."""
Computes log probabilities of the reference model.
"""
if not self.finetuning_args.use_ref_model: if not self.finetuning_args.use_ref_model:
return None, None return None, None
...@@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -236,12 +225,10 @@ class CustomDPOTrainer(DPOTrainer):
def get_batch_loss_metrics( def get_batch_loss_metrics(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"], batch: dict[str, "torch.Tensor"],
train_eval: Literal["train", "eval"] = "train", train_eval: Literal["train", "eval"] = "train",
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r""" r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {} metrics = {}
( (
policy_chosen_logps, policy_chosen_logps,
...@@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -279,18 +266,14 @@ class CustomDPOTrainer(DPOTrainer):
@override @override
def compute_loss( def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r""" r"""Subclass and override to accept extra kwargs."""
Subclass and override to accept extra kwargs.
"""
return super().compute_loss(model, inputs, return_outputs) return super().compute_loss(model, inputs, return_outputs)
@override @override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None: def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r""" r"""Log `logs` on the various objects watching training, including stored metrics."""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss" # logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs # Add averaged stored metrics to logs
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's TRL library. # This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
...@@ -38,7 +38,7 @@ def run_dpo( ...@@ -38,7 +38,7 @@ def run_dpo(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's TRL library. # This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
...@@ -19,7 +19,7 @@ import warnings ...@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict from collections import defaultdict
from contextlib import nullcontext from contextlib import nullcontext
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, Literal, Optional, Union
import torch import torch
from transformers import Trainer from transformers import Trainer
...@@ -120,28 +120,22 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -120,28 +120,22 @@ class CustomKTOTrainer(KTOTrainer):
@override @override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
r""" r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
if self.finetuning_args.disable_shuffling: if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset) return torch.utils.data.SequentialSampler(self.train_dataset)
return Trainer._get_train_sampler(self) return Trainer._get_train_sampler(self)
@override @override
-    def get_batch_samples(self, epoch_iterator, num_batches):
-        r"""
-        Replaces the method of KTO Trainer with the one of the standard Trainer.
-        """
-        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+    def get_batch_samples(self, *args, **kwargs):
+        r"""Replace the method of KTO Trainer with the one of the standard Trainer."""
+        return Trainer.get_batch_samples(self, *args, **kwargs)
@override @override
def forward( def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""Run forward pass and computes the log probabilities."""
Runs forward pass and computes the log probabilities.
"""
batch = nested_detach(batch, clone=True) # avoid error batch = nested_detach(batch, clone=True) # avoid error
model_inputs = { model_inputs = {
"input_ids": batch[f"{prefix}input_ids"], "input_ids": batch[f"{prefix}input_ids"],
...@@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -171,8 +165,8 @@ class CustomKTOTrainer(KTOTrainer):
@override @override
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logits, target_logps, target_logps_avg = self.forward(model, batch) target_logits, target_logps, target_logps_avg = self.forward(model, batch)
with torch.no_grad(): with torch.no_grad():
_, kl_logps, _ = self.forward(model, batch, prefix="kl_") _, kl_logps, _ = self.forward(model, batch, prefix="kl_")
...@@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -189,11 +183,9 @@ class CustomKTOTrainer(KTOTrainer):
@override @override
def compute_reference_log_probs( def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""Compute log probabilities of the reference model."""
Computes log probabilities of the reference model.
"""
if self.ref_model is None: if self.ref_model is None:
ref_model = model ref_model = model
ref_context = self.accelerator.unwrap_model(model).disable_adapter() ref_context = self.accelerator.unwrap_model(model).disable_adapter()
...@@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -212,11 +204,9 @@ class CustomKTOTrainer(KTOTrainer):
def get_batch_loss_metrics( def get_batch_loss_metrics(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"], batch: dict[str, "torch.Tensor"],
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]: ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]:
r""" r"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {} metrics = {}
( (
policy_chosen_logps, policy_chosen_logps,
...@@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -262,18 +252,14 @@ class CustomKTOTrainer(KTOTrainer):
@override @override
def compute_loss( def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs self, model: "PreTrainedModel", inputs: dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: ) -> Union["torch.Tensor", tuple["torch.Tensor", list["torch.Tensor"]]]:
r""" r"""Subclass and override to accept extra kwargs."""
Subclass and override to accept extra kwargs.
"""
return super().compute_loss(model, inputs, return_outputs) return super().compute_loss(model, inputs, return_outputs)
@override @override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None: def log(self, logs: dict[str, float], *args, **kwargs) -> None:
r""" r"""Log `logs` on the various objects watching training, including stored metrics."""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss" # logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval" train_eval = "train" if "loss" in logs else "eval"
prefix = "eval_" if train_eval == "eval" else "" prefix = "eval_" if train_eval == "eval" else ""
...@@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -291,7 +277,7 @@ class CustomKTOTrainer(KTOTrainer):
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device) metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
metric_list = self.accelerator.reduce(metric_list, "sum").tolist() metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list)) metric_dict: dict[str, float] = dict(zip(key_list, metric_list))
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
if f"count/{split}" in metric_dict: if f"count/{split}" in metric_dict:
for key in ("rewards", "logps", "logits"): for key in ("rewards", "logps", "logits"):
......
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
# #
# This code is inspired by the HuggingFace's TRL library. # This code is inspired by the HuggingFace's TRL library.
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
...@@ -37,7 +37,7 @@ def run_kto( ...@@ -37,7 +37,7 @@ def run_kto(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
callbacks: Optional[List["TrainerCallback"]] = None, callbacks: Optional[list["TrainerCallback"]] = None,
): ):
tokenizer_module = load_tokenizer(model_args) tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import json import json
from contextlib import nullcontext from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional from typing import TYPE_CHECKING, Literal, Optional
import torch import torch
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
...@@ -31,10 +31,8 @@ if TYPE_CHECKING: ...@@ -31,10 +31,8 @@ if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]: def get_rewards_from_server(server_url: str, messages: list[str]) -> list["torch.Tensor"]:
r""" r"""Get reward scores from the API server."""
Gets reward scores from the API server.
"""
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
payload = {"model": "model", "messages": messages} payload = {"model": "model", "messages": messages}
response = requests.post(server_url, json=payload, headers=headers) response = requests.post(server_url, json=payload, headers=headers)
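`get_rewards_from_server` posts the rendered conversations to a reward-model API and converts the returned scores into tensors; the response parsing is elided in this hunk. A hedged round-trip sketch (the endpoint URL and the "scores" field are assumptions, not the project's actual schema):

    import requests
    import torch

    def query_reward_server(server_url: str, messages: list[str]) -> list[torch.Tensor]:
        """Sketch of a reward-server call; the response field name is an assumption."""
        payload = {"model": "model", "messages": messages}
        response = requests.post(server_url, json=payload, headers={"Content-Type": "application/json"})
        response.raise_for_status()
        scores = response.json()["scores"]  # assumed response schema
        return [torch.tensor(score) for score in scores]

    # e.g. rewards = query_reward_server("http://localhost:8000/v1/score", rendered_messages)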
...@@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch ...@@ -43,9 +41,7 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
r""" r"""Replace the default/reward modules in the model. The model is already unwrapped."""
Replaces the default/reward modules in the model. The model is already unwrapped.
"""
v_head_layer = model.v_head.summary v_head_layer = model.v_head.summary
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore import deepspeed # type: ignore
...@@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d ...@@ -66,10 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device) v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
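Under DeepSpeed ZeRO-3 the value-head weight and bias are sharded, so they must be gathered with write access before the stored buffers can be copied over them; the gathered-copy block itself is elided in this hunk. A sketch of that pattern (the buffer names mirror the visible line; wrapping the copy in `GatheredParameters` is an assumption about the omitted code):

    import deepspeed  # type: ignore
    import torch

    def swap_value_head_zero3(model, v_head_layer: torch.nn.Linear, target: str) -> None:
        """Sketch: gather the sharded value-head tensors, then overwrite them from stored buffers."""
        params = [v_head_layer.weight, v_head_layer.bias]
        # modifier_rank=0 gathers full tensors, lets rank 0 modify them, and re-shards on exit.
        with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
            device = v_head_layer.weight.device
            v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
            v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)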
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]: def dump_layernorm(model: "PreTrainedModel") -> dict[str, "torch.Tensor"]:
r""" r"""Dump the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
layer_norm_params = {} layer_norm_params = {}
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if param.data.dtype == torch.float32: if param.data.dtype == torch.float32:
...@@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]: ...@@ -79,10 +73,8 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
return layer_norm_params return layer_norm_params
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None: def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[dict[str, "torch.Tensor"]] = None) -> None:
r""" r"""Restore the layernorm parameters in the model. The model is already unwrapped (and gathered)."""
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if name in layernorm_params: if name in layernorm_params:
param.data = layernorm_params[name] param.data = layernorm_params[name]
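`dump_layernorm` snapshots every float32 parameter (in a half-precision model, effectively the layernorm weights) and `restore_layernorm` writes the saved tensors back, so dtype tweaks made around generation do not leak into training. A minimal usage sketch wrapping a generate call (the try/finally framing is an assumption about how the helpers are meant to be paired):

    import torch

    def generate_with_layernorm_guard(model, generation_inputs: dict) -> torch.Tensor:
        """Snapshot float32 (layernorm) weights around generation, then restore them in place."""
        layernorm_params = dump_layernorm(model)         # copies of every float32 parameter
        try:
            return model.generate(**generation_inputs)   # weights may be cast or patched in here
        finally:
            restore_layernorm(model, layernorm_params)   # write the saved tensors back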