Commit 2778a3d0 authored by luopl
Browse files

update to v0.9.1_stable

parent e92143e3
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import inspect import inspect
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -24,7 +24,7 @@ if TYPE_CHECKING: ...@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def apply_liger_kernel( def apply_liger_kernel(
...@@ -54,14 +54,14 @@ def apply_liger_kernel( ...@@ -54,14 +54,14 @@ def apply_liger_kernel(
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
else: else:
logger.warning("Current model does not support liger kernel.") logger.warning_rank0("Current model does not support liger kernel.")
return return
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters: if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info("Current training stage does not support chunked cross entropy.") logger.info_rank0("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False} kwargs = {"fused_linear_cross_entropy": False}
else: else:
kwargs = {} kwargs = {}
apply_liger_kernel(**kwargs) apply_liger_kernel(**kwargs)
logger.info("Liger kernel has been applied to the model.") logger.info_rank0("Liger kernel has been applied to the model.")
...@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple ...@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import transformers
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
Cache, Cache,
LlamaAttention, LlamaAttention,
...@@ -30,12 +31,11 @@ from transformers.models.llama.modeling_llama import ( ...@@ -30,12 +31,11 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb, apply_rotary_pos_emb,
repeat_kv, repeat_kv,
) )
from transformers.utils import logging
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_greater_than
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -44,7 +44,7 @@ if TYPE_CHECKING: ...@@ -44,7 +44,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
transformers_logger = logging.get_logger(__name__) transformers_logger = transformers.utils.logging.get_logger(__name__)
# Modified from: # Modified from:
...@@ -86,7 +86,7 @@ def llama_attention_forward( ...@@ -86,7 +86,7 @@ def llama_attention_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor": def shift(state: "torch.Tensor") -> "torch.Tensor":
...@@ -195,7 +195,7 @@ def llama_flash_attention_2_forward( ...@@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor": def shift(state: "torch.Tensor") -> "torch.Tensor":
...@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward( ...@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1) attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
if is_transformers_version_greater_than_4_43(): if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward( attn_output: "torch.Tensor" = _flash_attention_forward(
...@@ -301,7 +301,7 @@ def llama_sdpa_attention_forward( ...@@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio")) groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor": def shift(state: "torch.Tensor") -> "torch.Tensor":
...@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward( ...@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None: def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2") require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
LlamaAttention.forward = llama_attention_forward LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward LlamaSdpaAttention.forward = llama_sdpa_attention_forward
...@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", ...@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
if not is_trainable or not model_args.shift_attn: if not is_trainable or not model_args.shift_attn:
return return
logger = get_logger(__name__) logger = logging.get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25) setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch() _apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.") logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
else: else:
logger.warning("Current model does not support shift short attention.") logger.warning_rank0("Current model does not support shift short attention.")
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
...@@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) ...@@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output_layer") forbidden_modules.add("output_layer")
elif model_type == "internlm2": elif model_type == "internlm2":
forbidden_modules.add("output") forbidden_modules.add("output")
elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
forbidden_modules.add("multi_modal_projector") forbidden_modules.add("multi_modal_projector")
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
forbidden_modules.add("merger") forbidden_modules.add("merger")
if freeze_vision_tower: if freeze_vision_tower:
if model_type == "qwen2_vl": if model_type == "mllama":
forbidden_modules.add("vision_model")
elif model_type == "qwen2_vl":
forbidden_modules.add("visual") forbidden_modules.add("visual")
else: else:
forbidden_modules.add("vision_tower") forbidden_modules.add("vision_tower")
...@@ -53,7 +55,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) ...@@ -53,7 +55,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1]) module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names))) logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
return list(module_names) return list(module_names)
...@@ -67,12 +69,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n ...@@ -67,12 +69,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
if num_layers % num_layer_trainable != 0: if num_layers % num_layer_trainable != 0:
raise ValueError( raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable) f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
) )
stride = num_layers // num_layer_trainable stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids] trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
module_names = [] module_names = []
for name, _ in model.named_modules(): for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any( if any(target_module in name for target_module in target_modules) and any(
...@@ -80,7 +82,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n ...@@ -80,7 +82,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
): ):
module_names.append(name) module_names.append(name)
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names return module_names
......
...@@ -43,9 +43,9 @@ import torch ...@@ -43,9 +43,9 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_greater_than
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -54,7 +54,7 @@ if TYPE_CHECKING: ...@@ -54,7 +54,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
...@@ -114,8 +114,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor ...@@ -114,8 +114,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None: def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2") require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
if is_transformers_version_greater_than_4_43(): if is_transformers_version_greater_than("4.43.0"):
import transformers.modeling_flash_attention_utils import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
...@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", ...@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
_patch_for_block_diag_attn(model_type) _patch_for_block_diag_attn(model_type)
logger.info("Using block diagonal attention for sequence packing without cross-attention.") logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
else: else:
raise ValueError("Current model does not support block diagonal attention.") raise ValueError("Current model does not support block diagonal attention.")
...@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled ...@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import FILEEXT2TYPE from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device from ...extras.misc import get_current_device
...@@ -39,7 +39,7 @@ if TYPE_CHECKING: ...@@ -39,7 +39,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
@unique @unique
...@@ -109,7 +109,7 @@ def configure_quantization( ...@@ -109,7 +109,7 @@ def configure_quantization(
""" """
if getattr(config, "quantization_config", None): # ptq if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None: if model_args.quantization_bit is not None:
logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.") logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.") raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
...@@ -130,7 +130,7 @@ def configure_quantization( ...@@ -130,7 +130,7 @@ def configure_quantization(
quantization_config["bits"] = 2 quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?") quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq elif model_args.export_quantization_bit is not None: # auto-gptq
if model_args.export_quantization_bit not in [8, 4, 3, 2]: if model_args.export_quantization_bit not in [8, 4, 3, 2]:
...@@ -149,7 +149,7 @@ def configure_quantization( ...@@ -149,7 +149,7 @@ def configure_quantization(
) )
init_kwargs["device_map"] = "auto" init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory() init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit)) logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
elif model_args.quantization_bit is not None: # on-the-fly elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value: if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
...@@ -179,7 +179,7 @@ def configure_quantization( ...@@ -179,7 +179,7 @@ def configure_quantization(
else: else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit)) logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
elif model_args.quantization_method == QuantizationMethod.HQQ.value: elif model_args.quantization_method == QuantizationMethod.HQQ.value:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]: if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.") raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
...@@ -191,7 +191,7 @@ def configure_quantization( ...@@ -191,7 +191,7 @@ def configure_quantization(
init_kwargs["quantization_config"] = HqqConfig( init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance ) # use ATEN kernel (axis=0) for performance
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit)) logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
elif model_args.quantization_method == QuantizationMethod.EETQ.value: elif model_args.quantization_method == QuantizationMethod.EETQ.value:
if model_args.quantization_bit != 8: if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.") raise ValueError("EETQ only accepts 8-bit quantization.")
...@@ -201,4 +201,4 @@ def configure_quantization( ...@@ -201,4 +201,4 @@ def configure_quantization(
require_version("eetq", "To fix: pip install eetq") require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig() init_kwargs["quantization_config"] = EetqConfig()
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit)) logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
import math import math
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -28,7 +28,7 @@ if TYPE_CHECKING: ...@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
...@@ -36,30 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_ ...@@ -36,30 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
return return
if not hasattr(config, "rope_scaling"): if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.") logger.warning_rank0("Current model does not support RoPE scaling.")
return return
if model_args.model_max_length is not None: if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == "dynamic": if is_trainable and model_args.rope_scaling == "dynamic":
logger.warning( logger.warning_rank0(
"Dynamic NTK scaling may not work well with fine-tuning. " "Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653" "See: https://github.com/huggingface/transformers/pull/24653"
) )
current_max_length = getattr(config, "max_position_embeddings", None) current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length: if current_max_length and model_args.model_max_length > current_max_length:
logger.info( logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
)
setattr(config, "max_position_embeddings", model_args.model_max_length) setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else: else:
logger.warning("Input length is smaller than max length. Consider increase input length.") logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0 scaling_factor = 1.0
else: else:
scaling_factor = 2.0 scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info( logger.info_rank0(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor) f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
) )
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger from ...extras import logging
from ...extras.misc import get_current_device from ...extras.misc import get_current_device
...@@ -24,7 +24,7 @@ if TYPE_CHECKING: ...@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def _get_unsloth_kwargs( def _get_unsloth_kwargs(
...@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model( ...@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
try: try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError: except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model = None model = None
model_args.use_unsloth = False model_args.use_unsloth = False
......
...@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict ...@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
import torch import torch
from transformers.utils import cached_file from transformers.utils import cached_file
from ...extras import logging
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -27,7 +27,7 @@ if TYPE_CHECKING: ...@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments from ...hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
...@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> ...@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
except Exception as err: except Exception as err:
err_text = str(err) err_text = str(err)
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text)) logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
logger.info("Ignore the above message if you are not resuming the training of a value head model.") logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
return None return None
......
...@@ -18,21 +18,21 @@ ...@@ -18,21 +18,21 @@
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
import torch import torch
import transformers
import transformers.models import transformers.models
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.utils import logging
from ...extras.logging import get_logger from ...extras import logging
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments from ...hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
transformers_logger = logging.get_logger(__name__) transformers_logger = transformers.utils.logging.get_logger(__name__)
class LlavaMultiModalProjectorForYiVL(torch.nn.Module): class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
...@@ -92,14 +92,14 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen ...@@ -92,14 +92,14 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
if getattr(model, "quantization_method", None): if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None) model_type = getattr(model.config, "model_type", None)
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector") mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger") mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
else: else:
return return
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
mm_projector.register_forward_hook(_mm_projector_forward_post_hook) mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
...@@ -113,12 +113,13 @@ def configure_visual_model(config: "PretrainedConfig") -> None: ...@@ -113,12 +113,13 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
"llava_next", "llava_next",
"llava_next_video", "llava_next_video",
"paligemma", "paligemma",
"pixtral",
"video_llava", "video_llava",
]: # required for ds zero3 and valuehead models ]: # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None): if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.") logger.info_rank0("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
...@@ -128,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni ...@@ -128,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
""" """
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
forbidden_modules = set() forbidden_modules = set()
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
if finetuning_args.freeze_vision_tower: if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower") forbidden_modules.add("vision_tower")
...@@ -162,19 +163,21 @@ def get_image_seqlen(config: "PretrainedConfig") -> int: ...@@ -162,19 +163,21 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
return image_seqlen return image_seqlen
def get_patch_size(config: "PretrainedConfig") -> int: def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r""" r"""
Computes the patch size of the vit. Computes the patch size of the vit.
""" """
patch_size = getattr(config.vision_config, "patch_size", -1) patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
return patch_size return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> str:
    r"""
    Get the vision_feature_select_strategy.

    The config attribute takes precedence; otherwise the processor attribute of the same
    name is used, and "default" is returned when neither object defines it.
    """
    fallback = getattr(processor, "vision_feature_select_strategy", "default")
    return getattr(config, "vision_feature_select_strategy", fallback)
...@@ -186,8 +189,10 @@ def patch_target_modules( ...@@ -186,8 +189,10 @@ def patch_target_modules(
""" """
model_type = getattr(config, "model_type", None) model_type = getattr(config, "model_type", None)
if finetuning_args.freeze_vision_tower: if finetuning_args.freeze_vision_tower:
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif model_type == "mllama":
return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))
elif model_type == "qwen2_vl": elif model_type == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules)) return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
else: else:
...@@ -195,5 +200,7 @@ def patch_target_modules( ...@@ -195,5 +200,7 @@ def patch_target_modules(
else: else:
if model_type == "qwen2_vl": if model_type == "qwen2_vl":
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules)) return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
elif model_type == "pixtral":
return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
else: else:
return target_modules return target_modules
...@@ -22,7 +22,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_ ...@@ -22,7 +22,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger from ..extras import logging
from ..extras.misc import infer_optim_dtype from ..extras.misc import infer_optim_dtype
from .model_utils.attention import configure_attn_implementation, print_attn_implementation from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training from .model_utils.checkpointing import prepare_model_for_training
...@@ -49,7 +49,7 @@ if TYPE_CHECKING: ...@@ -49,7 +49,7 @@ if TYPE_CHECKING:
from ..hparams import ModelArguments from ..hparams import ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
...@@ -66,11 +66,11 @@ def patch_processor( ...@@ -66,11 +66,11 @@ def patch_processor(
setattr(processor, "tokenizer", tokenizer) setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config)) setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution) setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "patch_size", get_patch_size(config)) setattr(processor, "patch_size", get_patch_size(config, processor))
setattr(processor, "video_resolution", model_args.video_resolution) setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps) setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen) setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config)) setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
def patch_config( def patch_config(
...@@ -100,7 +100,7 @@ def patch_config( ...@@ -100,7 +100,7 @@ def patch_config(
if model_args.use_cache and not is_trainable: if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True) setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.") logger.info_rank0("Using KV cache for faster generation.")
if getattr(config, "model_type", None) == "qwen": if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2") setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
...@@ -165,7 +165,7 @@ def patch_model( ...@@ -165,7 +165,7 @@ def patch_model(
try: try:
model.add_model_tags(["llama-factory"]) model.add_model_tags(["llama-factory"])
except Exception: except Exception:
logger.warning("Cannot properly tag the model.") logger.warning_rank0("Cannot properly tag the model.")
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import json import json
import logging
import os import os
import signal import signal
import sys import sys
...@@ -34,8 +33,8 @@ from transformers.utils import ( ...@@ -34,8 +33,8 @@ from transformers.utils import (
) )
from typing_extensions import override from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger
from ..extras.misc import get_peak_memory from ..extras.misc import get_peak_memory
...@@ -48,7 +47,7 @@ if TYPE_CHECKING: ...@@ -48,7 +47,7 @@ if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__) logger = logging.get_logger(__name__)
def fix_valuehead_checkpoint( def fix_valuehead_checkpoint(
...@@ -92,7 +91,7 @@ def fix_valuehead_checkpoint( ...@@ -92,7 +91,7 @@ def fix_valuehead_checkpoint(
else: else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
logger.info("Value head model saved at: {}".format(output_dir)) logger.info_rank0(f"Value head model saved at: {output_dir}")
class FixValueHeadModelCallback(TrainerCallback): class FixValueHeadModelCallback(TrainerCallback):
...@@ -106,7 +105,7 @@ class FixValueHeadModelCallback(TrainerCallback): ...@@ -106,7 +105,7 @@ class FixValueHeadModelCallback(TrainerCallback):
Event called after a checkpoint save. Event called after a checkpoint save.
""" """
if args.should_save: if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)) output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
fix_valuehead_checkpoint( fix_valuehead_checkpoint(
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
) )
...@@ -123,13 +122,13 @@ class SaveProcessorCallback(TrainerCallback): ...@@ -123,13 +122,13 @@ class SaveProcessorCallback(TrainerCallback):
@override @override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
    r"""
    Saves the whole processor into the checkpoint folder of the current global step.

    Nothing is written unless `args.should_save` is set.
    """
    if args.should_save:
        output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        # save the full processor rather than only its image_processor component
        self.processor.save_pretrained(output_dir)
@override @override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
    r"""
    Saves the whole processor into `args.output_dir` once training has finished.

    Nothing is written unless `args.should_save` is set.
    """
    if args.should_save:
        # save the full processor rather than only its image_processor component
        self.processor.save_pretrained(args.output_dir)
class PissaConvertCallback(TrainerCallback): class PissaConvertCallback(TrainerCallback):
...@@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback): ...@@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback):
if args.should_save: if args.should_save:
model = kwargs.pop("model") model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init") pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir)) logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
if isinstance(model, PeftModel): if isinstance(model, PeftModel):
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
setattr(model.peft_config["default"], "init_lora_weights", True) setattr(model.peft_config["default"], "init_lora_weights", True)
...@@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback): ...@@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback):
pissa_init_dir = os.path.join(args.output_dir, "pissa_init") pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup") pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted") pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir)) logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
# 1. save a pissa backup with init_lora_weights: True # 1. save a pissa backup with init_lora_weights: True
# 2. save a converted lora with init_lora_weights: pissa # 2. save a converted lora with init_lora_weights: pissa
# 3. load the pissa backup with init_lora_weights: True # 3. load the pissa backup with init_lora_weights: True
...@@ -200,8 +199,8 @@ class LogCallback(TrainerCallback): ...@@ -200,8 +199,8 @@ class LogCallback(TrainerCallback):
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode: if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort) signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
logging.root.addHandler(self.logger_handler) logging.add_handler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler) transformers.logging.add_handler(self.logger_handler)
def _set_abort(self, signum, frame) -> None: def _set_abort(self, signum, frame) -> None:
...@@ -243,7 +242,7 @@ class LogCallback(TrainerCallback): ...@@ -243,7 +242,7 @@ class LogCallback(TrainerCallback):
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG)) and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir and args.overwrite_output_dir
): ):
logger.warning("Previous trainer log in this folder will be deleted.") logger.warning_once("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG)) os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@override @override
...@@ -288,13 +287,13 @@ class LogCallback(TrainerCallback): ...@@ -288,13 +287,13 @@ class LogCallback(TrainerCallback):
logs = dict( logs = dict(
current_steps=self.cur_steps, current_steps=self.cur_steps,
total_steps=self.max_steps, total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None), loss=state.log_history[-1].get("loss"),
eval_loss=state.log_history[-1].get("eval_loss", None), eval_loss=state.log_history[-1].get("eval_loss"),
predict_loss=state.log_history[-1].get("predict_loss", None), predict_loss=state.log_history[-1].get("predict_loss"),
reward=state.log_history[-1].get("reward", None), reward=state.log_history[-1].get("reward"),
accuracy=state.log_history[-1].get("rewards/accuracies", None), accuracy=state.log_history[-1].get("rewards/accuracies"),
learning_rate=state.log_history[-1].get("learning_rate", None), lr=state.log_history[-1].get("learning_rate"),
epoch=state.log_history[-1].get("epoch", None), epoch=state.log_history[-1].get("epoch"),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time, elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time, remaining_time=self.remaining_time,
...@@ -305,16 +304,17 @@ class LogCallback(TrainerCallback): ...@@ -305,16 +304,17 @@ class LogCallback(TrainerCallback):
if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]: if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
vram_allocated, vram_reserved = get_peak_memory() vram_allocated, vram_reserved = get_peak_memory()
logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2) logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2) logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)
logs = {k: v for k, v in logs.items() if v is not None} logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]): if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
logger.info( log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format( for extra_key in ("reward", "accuracy", "throughput"):
logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A") if logs.get(extra_key):
) log_str += f", '{extra_key}': {logs[extra_key]:.2f}"
)
logger.info_rank0("{" + log_str + "}")
if self.thread_pool is not None: if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs) self.thread_pool.submit(self._write_log, args.output_dir, logs)
......
...@@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model ...@@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
...@@ -100,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -100,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
self.callback_handler.add_callback(PissaConvertCallback) self.callback_handler.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
...@@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
@override
def get_batch_samples(self, epoch_iterator, num_batches):
    r"""
    Replaces the method of DPO Trainer with the one of the standard Trainer.
    """
    # call the base Trainer implementation directly, bypassing the parent trainer's override
    return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r""" r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model. Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
...@@ -156,7 +164,7 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -156,7 +164,7 @@ class CustomDPOTrainer(DPOTrainer):
elif self.loss_type == "simpo": elif self.loss_type == "simpo":
losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps) losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
else: else:
raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type)) raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach() chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach() rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
...@@ -242,19 +250,59 @@ class CustomDPOTrainer(DPOTrainer): ...@@ -242,19 +250,59 @@ class CustomDPOTrainer(DPOTrainer):
if self.ftx_gamma > 1e-6: if self.ftx_gamma > 1e-6:
losses += self.ftx_gamma * sft_loss losses += self.ftx_gamma * sft_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else "" prefix = "eval_" if train_eval == "eval" else ""
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu() metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu() metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu() metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu() metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu() metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu() metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu() metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu() metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
if self.loss_type == "orpo": if self.loss_type == "orpo":
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu() metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu() metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()
return losses.mean(), metrics return losses.mean(), metrics
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    r"""
    Fixes the loss value for transformers 4.46.0.
    https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

    The parent loss is computed first. When the transformers version check passes and a truthy
    `num_items_in_batch` keyword was supplied, the loss is divided by the number of gradient
    accumulation steps; with `return_outputs` only the first element of the result is rescaled.
    """
    result = super().compute_loss(model, inputs, return_outputs)
    needs_rescale = is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False)
    if not needs_rescale:
        return result

    if return_outputs:
        return (result[0] / self.args.gradient_accumulation_steps, *result[1:])

    return result / self.args.gradient_accumulation_steps
@override
def log(self, logs: Dict[str, float]) -> None:
    r"""
    Merges the averaged stored metrics into `logs` and hands them to the standard Trainer logger.

    Each stored metric is averaged locally, the averages are mean-reduced through the
    accelerator, and every non-placeholder entry is written into `logs`.
    """
    # a "loss" entry marks a training log, anything else is treated as an evaluation log
    split = "train" if "loss" in logs else "eval"
    stored = self._stored_metrics[split]
    names = list(stored.keys())
    averages = [
        torch.tensor(values, dtype=torch.float).to(self.accelerator.device).mean().item()
        for values in stored.values()
    ]
    del self._stored_metrics[split]

    # pad to at least 10 entries before the reduce (presumably to keep the tensor length
    # equal on every process); placeholders are dropped again below
    num_pad = 10 - len(averages)
    names.extend(f"dummy_{i}" for i in range(num_pad))
    averages.extend([0.0] * num_pad)

    local_tensor = torch.tensor(averages, dtype=torch.float).to(self.accelerator.device)
    reduced = self.accelerator.reduce(local_tensor, "mean").tolist()
    logs.update({name: value for name, value in zip(names, reduced) if not name.startswith("dummy_")})
    return Trainer.log(self, logs)
...@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional ...@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens
from ...extras.ploting import plot_loss from ...extras.ploting import plot_loss
from ...hparams import ModelArguments from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer from ...model import load_model, load_tokenizer
...@@ -64,6 +65,12 @@ def run_dpo( ...@@ -64,6 +65,12 @@ def run_dpo(
# Update arguments # Update arguments
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
effective_token_num = 0.0
if finetuning_args.include_effective_tokens_per_second:
for data in dataset_module["train_dataset"]:
effective_token_num += len(data["chosen_input_ids"])
effective_token_num += len(data["rejected_input_ids"])
# Initialize our Trainer # Initialize our Trainer
trainer = CustomDPOTrainer( trainer = CustomDPOTrainer(
model=model, model=model,
...@@ -79,6 +86,12 @@ def run_dpo( ...@@ -79,6 +86,12 @@ def run_dpo(
# Training # Training
if training_args.do_train: if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
)
trainer.save_model() trainer.save_model()
trainer.log_metrics("train", train_result.metrics) trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
......
...@@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model ...@@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import SaveProcessorCallback from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
...@@ -95,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -95,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
self.add_callback(SaveProcessorCallback(processor)) self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
...@@ -120,20 +121,27 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -120,20 +121,27 @@ class CustomKTOTrainer(KTOTrainer):
""" """
return Trainer._get_train_sampler(self) return Trainer._get_train_sampler(self)
@override
def get_batch_samples(self, epoch_iterator, num_batches):
    r"""
    Uses the batch sampling of the standard Trainer instead of the one of KTO Trainer.
    """
    # explicit unbound call so that the parent trainer's own override is skipped
    return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
@override @override
def forward( def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r""" r"""
Runs forward pass and computes the log probabilities. Runs forward pass and computes the log probabilities.
""" """
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
model_inputs = { model_inputs = {
"input_ids": batch["{}input_ids".format(prefix)], "input_ids": batch[f"{prefix}input_ids"],
"attention_mask": batch["{}attention_mask".format(prefix)], "attention_mask": batch[f"{prefix}attention_mask"],
} }
if "{}token_type_ids".format(prefix) in batch: if f"{prefix}token_type_ids" in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)] model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]
if "pixel_values" in batch: if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"] model_inputs["pixel_values"] = batch["pixel_values"]
...@@ -142,24 +150,26 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -142,24 +150,26 @@ class CustomKTOTrainer(KTOTrainer):
model_inputs["image_grid_thw"] = batch["image_grid_thw"] model_inputs["image_grid_thw"] = batch["image_grid_thw"]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)]) logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
return logps, logps / valid_length return logits, logps, logps / valid_length
@override @override
def concatenated_forward(
    self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    r"""
    Runs the target and KL forward passes and splits the target outputs by `kto_tags`.

    Returns a tuple of (chosen_logps, rejected_logps, chosen_logits, rejected_logits,
    kl_logps, chosen_logps_avg). Entries selected by `batch["kto_tags"]` are the chosen
    ones, the inverted mask selects the rejected ones.

    Raises ValueError when the number of target log probabilities differs from the
    number of tags.
    """
    target_logits, target_logps, target_logps_avg = self.forward(model, batch)
    with torch.no_grad():  # the KL pass only provides log probabilities, no gradients needed
        _, kl_logps, _ = self.forward(model, batch, prefix="kl_")

    kto_tags = batch["kto_tags"]
    if len(target_logps) != len(kto_tags):
        raise ValueError("Mismatched shape of inputs and labels.")

    chosen_logits = target_logits[kto_tags]
    chosen_logps = target_logps[kto_tags]
    rejected_logits = target_logits[~kto_tags]
    rejected_logps = target_logps[~kto_tags]
    chosen_logps_avg = target_logps_avg[kto_tags]
    return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg
@override @override
def compute_reference_log_probs( def compute_reference_log_probs(
...@@ -176,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -176,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
ref_context = nullcontext() ref_context = nullcontext()
with torch.no_grad(), ref_context: with torch.no_grad(), ref_context:
reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward( reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
ref_model, batch ref_model, batch
) )
...@@ -192,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -192,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
Computes the DPO loss and other metrics for the given batch of inputs for train or test. Computes the DPO loss and other metrics for the given batch of inputs for train or test.
""" """
metrics = {} metrics = {}
policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = ( (
self.concatenated_forward(model, batch) policy_chosen_logps,
) policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_kl_logps,
policy_chosen_logps_avg,
) = self.concatenated_forward(model, batch)
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs( reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
model, batch model, batch
) )
...@@ -212,22 +227,73 @@ class CustomKTOTrainer(KTOTrainer): ...@@ -212,22 +227,73 @@ class CustomKTOTrainer(KTOTrainer):
sft_loss = -policy_chosen_logps_avg sft_loss = -policy_chosen_logps_avg
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"]) losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) num_chosen = len(chosen_rewards)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) num_rejected = len(rejected_rewards)
if num_chosen > 0:
metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
metrics["count/chosen"] = float(num_chosen)
all_num_chosen = self.accelerator.gather(num_chosen).sum().item() if num_rejected > 0:
all_num_rejected = self.accelerator.gather(num_rejected).sum().item() metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
metrics["count/rejected"] = float(num_rejected)
if all_num_chosen > 0: metrics["kl"] = kl.item()
metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item() return losses, metrics
metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
metrics["count/chosen"] = all_num_chosen
if all_num_rejected > 0: @override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    r"""
    Fixes the loss value for transformers 4.46.0.
    https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605

    The parent loss is computed first. When the transformers version check passes and a truthy
    `num_items_in_batch` keyword was supplied, the loss is divided by the number of gradient
    accumulation steps; with `return_outputs` only the first element of the result is rescaled.
    """
    loss = super().compute_loss(model, inputs, return_outputs)
    if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
        if return_outputs:
            return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])

        return loss / self.args.gradient_accumulation_steps

    return loss
return losses, metrics @override
def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
    r"""
    Log `logs` on the various objects watching training, including stored metrics.

    The metrics stored for the current split are summed locally, reduced (summed) across
    processes, turned into per-split averages and merged into `logs`, which is then handed
    to `Trainer.log`.

    Args:
        logs: values to log; a "loss" key marks a training log, anything else is treated as eval.
        *args, **kwargs: forwarded untouched to `Trainer.log`, so the override keeps working
            when the caller passes extra arguments (e.g. `start_time` in newer transformers).
    """
    # logs either has "loss" or "eval_loss"
    train_eval = "train" if "loss" in logs else "eval"
    prefix = "eval_" if train_eval == "eval" else ""

    # Collapse every stored metric list into a single local sum
    key_list, metric_list = [], []
    for key, metrics in self._stored_metrics[train_eval].items():
        key_list.append(key)
        metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())

    del self._stored_metrics[train_eval]

    # Pad to a fixed length of 9 entries for the all-reduce (presumably so that every
    # process contributes a tensor of the same size); a non-positive count pads nothing.
    num_padding = 9 - len(metric_list)
    key_list.extend(f"dummy_{i}" for i in range(num_padding))
    metric_list.extend(0.0 for _ in range(num_padding))

    # NOTE(review): the reduce is positional and the keys are taken from this process only,
    # so this assumes every process stored the same keys in the same order -- confirm.
    metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
    metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
    metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))

    for split in ["chosen", "rejected"]:  # turn the summed values into averages using the counts
        if f"count/{split}" in metric_dict:
            for key in ("rewards", "logps", "logits"):
                logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
                del metric_dict[f"{key}/{split}_sum"]

            del metric_dict[f"count/{split}"]

    if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:  # calculate reward margin
        logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]

    for key, metric in metric_dict.items():  # add remaining items, skipping the padding entries
        if not key.startswith("dummy_"):
            logs[key] = metric

    return Trainer.log(self, logs, *args, **kwargs)
...@@ -81,7 +81,7 @@ def run_kto( ...@@ -81,7 +81,7 @@ def run_kto(
trainer.save_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics)
trainer.save_state() trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss: if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"]) plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"])
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:
......
...@@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d ...@@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone()) setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
device = v_head_layer.weight.device device = v_head_layer.weight.device
v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device) v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device) v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]: def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
......
...@@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits ...@@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override from typing_extensions import override
from ...extras.logging import get_logger from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
...@@ -58,7 +58,7 @@ if TYPE_CHECKING: ...@@ -58,7 +58,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class CustomPPOTrainer(PPOTrainer, Trainer): class CustomPPOTrainer(PPOTrainer, Trainer):
...@@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ...@@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
] ]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
if ppo_config.log_with is not None: if ppo_config.log_with is not None:
logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.") logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
ppo_config.log_with = None ppo_config.log_with = None
# Create optimizer and scheduler # Create optimizer and scheduler
...@@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ...@@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
) )
if self.args.max_steps > 0: if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs") logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")
self.amp_context = torch.autocast(self.current_device.type) self.amp_context = torch.autocast(self.current_device.type)
warnings.simplefilter("ignore") # remove gc warnings on ref model warnings.simplefilter("ignore") # remove gc warnings on ref model
...@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ...@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.add_callback(SaveProcessorCallback(processor)) self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
...@@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ...@@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero() self.state.is_world_process_zero = self.is_world_process_zero()
if self.is_world_process_zero(): logger.info_rank0("***** Running training *****")
logger.info("***** Running training *****") logger.info_rank0(f" Num examples = {num_examples:,}")
logger.info(" Num examples = {:,}".format(num_examples)) logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
logger.info(" Num Epochs = {:,}".format(num_train_epochs)) logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size)) logger.info_rank0(
logger.info( " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format( total_train_batch_size
total_train_batch_size
)
) )
logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps)) )
logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs)) logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
logger.info(" Total training steps = {:,}".format(max_steps)) logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0])) logger.info_rank0(f" Total training steps = {max_steps:,}")
logger.info_rank0(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")
dataiter = iter(self.dataloader) dataiter = iter(self.dataloader)
loss_meter = AverageMeter() loss_meter = AverageMeter()
...@@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ...@@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True) batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards) self.log_stats(stats, batch, rewards)
except Exception: except Exception:
logger.warning("Failed to save stats due to unknown errors.") logger.warning_rank0("Failed to save stats due to unknown errors.")
self.state.global_step += 1 self.state.global_step += 1
self.callback_handler.on_step_end(self.args, self.state, self.control) self.callback_handler.on_step_end(self.args, self.state, self.control)
...@@ -290,7 +289,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ...@@ -290,7 +289,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if (step + 1) % self.args.save_steps == 0: # save checkpoint if (step + 1) % self.args.save_steps == 0: # save checkpoint
self.save_model( self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)) os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
) )
self.callback_handler.on_save(self.args, self.state, self.control) self.callback_handler.on_save(self.args, self.state, self.control)
...@@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ...@@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.args.should_save: if self.args.should_save:
self._save(output_dir, state_dict=state_dict) self._save(output_dir, state_dict=state_dict)
except ValueError: except ValueError:
logger.warning( logger.warning_rank0(
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead," " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
" use zero_to_fp32.py to recover weights" " use zero_to_fp32.py to recover weights"
) )
......
...@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional ...@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from transformers import Trainer from transformers import Trainer
from typing_extensions import override from typing_extensions import override
from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
...@@ -30,9 +30,6 @@ if TYPE_CHECKING: ...@@ -30,9 +30,6 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments from ...hparams import FinetuningArguments
logger = get_logger(__name__)
class CustomTrainer(Trainer): class CustomTrainer(Trainer):
r""" r"""
Inherits Trainer for custom optimizer. Inherits Trainer for custom optimizer.
...@@ -51,7 +48,7 @@ class CustomTrainer(Trainer): ...@@ -51,7 +48,7 @@ class CustomTrainer(Trainer):
self.add_callback(PissaConvertCallback) self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
...@@ -68,3 +65,19 @@ class CustomTrainer(Trainer): ...@@ -68,3 +65,19 @@ class CustomTrainer(Trainer):
) -> "torch.optim.lr_scheduler.LRScheduler": ) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    r"""
    Fixes the loss value for transformers 4.46.0.

    Delegates to the parent `compute_loss`; on transformers 4.46, when the trainer does not
    report `model_accepts_loss_kwargs`, the loss is divided by the gradient accumulation steps.

    https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
    """
    result = super().compute_loss(model, inputs, return_outputs, **kwargs)
    needs_rescale = is_transformers_version_equal_to_4_46() and not getattr(
        self, "model_accepts_loss_kwargs", False
    )
    if not needs_rescale:
        return result

    # other model should not scale the loss
    if not return_outputs:
        return result / self.args.gradient_accumulation_steps

    # with outputs requested the loss is the first element; the rest passes through unchanged
    scaled_loss = result[0] / self.args.gradient_accumulation_steps
    return (scaled_loss, *result[1:])
...@@ -24,7 +24,8 @@ import torch ...@@ -24,7 +24,8 @@ import torch
from transformers import Trainer from transformers import Trainer
from typing_extensions import override from typing_extensions import override
from ...extras.logging import get_logger from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
...@@ -36,7 +37,7 @@ if TYPE_CHECKING: ...@@ -36,7 +37,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments from ...hparams import FinetuningArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class PairwiseTrainer(Trainer): class PairwiseTrainer(Trainer):
...@@ -59,7 +60,7 @@ class PairwiseTrainer(Trainer): ...@@ -59,7 +60,7 @@ class PairwiseTrainer(Trainer):
self.add_callback(PissaConvertCallback) self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
...@@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer): ...@@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
@override @override
def compute_loss( def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r""" r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
...@@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer): ...@@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze() chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean() loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0
if return_outputs: if return_outputs:
return loss, (loss, chosen_scores, rejected_scores) return loss, (loss, chosen_scores, rejected_scores)
else: else:
...@@ -113,7 +118,7 @@ class PairwiseTrainer(Trainer): ...@@ -113,7 +118,7 @@ class PairwiseTrainer(Trainer):
return return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}") logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
chosen_scores, rejected_scores = predict_results.predictions chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer: with open(output_prediction_file, "w", encoding="utf-8") as writer:
......
...@@ -25,8 +25,9 @@ import torch ...@@ -25,8 +25,9 @@ import torch
from transformers import Seq2SeqTrainer from transformers import Seq2SeqTrainer
from typing_extensions import override from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
...@@ -39,7 +40,7 @@ if TYPE_CHECKING: ...@@ -39,7 +40,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments from ...hparams import FinetuningArguments
logger = get_logger(__name__) logger = logging.get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer): class CustomSeq2SeqTrainer(Seq2SeqTrainer):
...@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): ...@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.add_callback(PissaConvertCallback) self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam: if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
...@@ -78,6 +79,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): ...@@ -78,6 +79,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    r"""
    Fixes the loss value for transformers 4.46.0.

    The parent `compute_loss` does the actual work; its loss is divided by the gradient
    accumulation steps only on transformers 4.46 and only when the trainer does not set
    `model_accepts_loss_kwargs`.

    https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
    """
    outcome = super().compute_loss(model, inputs, return_outputs, **kwargs)
    if not is_transformers_version_equal_to_4_46():
        return outcome

    if getattr(self, "model_accepts_loss_kwargs", False):
        return outcome

    # other model should not scale the loss
    if return_outputs:
        loss_value, *extras = outcome
        return (loss_value / self.args.gradient_accumulation_steps, *extras)

    return outcome / self.args.gradient_accumulation_steps
@override @override
def prediction_step( def prediction_step(
self, self,
...@@ -129,7 +146,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): ...@@ -129,7 +146,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}") logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
labels = np.where( labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment