Commit 2778a3d0 authored by luopl's avatar luopl
Browse files

updata to v0.9.1_stable

parent e92143e3
......@@ -15,7 +15,7 @@
import inspect
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
from ...extras import logging
if TYPE_CHECKING:
......@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def apply_liger_kernel(
......@@ -54,14 +54,14 @@ def apply_liger_kernel(
elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
else:
logger.warning("Current model does not support liger kernel.")
logger.warning_rank0("Current model does not support liger kernel.")
return
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info("Current training stage does not support chunked cross entropy.")
logger.info_rank0("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False}
else:
kwargs = {}
apply_liger_kernel(**kwargs)
logger.info("Liger kernel has been applied to the model.")
logger.info_rank0("Liger kernel has been applied to the model.")
......@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
import transformers
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
......@@ -30,12 +31,11 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
......@@ -44,7 +44,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
transformers_logger = logging.get_logger(__name__)
transformers_logger = transformers.utils.logging.get_logger(__name__)
# Modified from:
......@@ -86,7 +86,7 @@ def llama_attention_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor":
......@@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor":
......@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
if is_transformers_version_greater_than_4_43():
if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(
......@@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor":
......@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
......@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
if not is_trainable or not model_args.shift_attn:
return
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
logger.warning_rank0("Current model does not support shift short attention.")
......@@ -14,14 +14,14 @@
from typing import TYPE_CHECKING, List
from ...extras.logging import get_logger
from ...extras import logging
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
......@@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output_layer")
elif model_type == "internlm2":
forbidden_modules.add("output")
elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
forbidden_modules.add("multi_modal_projector")
elif model_type == "qwen2_vl":
forbidden_modules.add("merger")
if freeze_vision_tower:
if model_type == "qwen2_vl":
if model_type == "mllama":
forbidden_modules.add("vision_model")
elif model_type == "qwen2_vl":
forbidden_modules.add("visual")
else:
forbidden_modules.add("vision_tower")
......@@ -53,7 +55,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
......@@ -67,12 +69,12 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
if num_layers % num_layer_trainable != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
f"`num_layers` {num_layers} should be divisible by `num_layer_trainable` {num_layer_trainable}."
)
stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
trainable_layers = [f".{idx:d}." for idx in trainable_layer_ids]
module_names = []
for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any(
......@@ -80,7 +82,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
):
module_names.append(name)
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names
......
......@@ -43,9 +43,9 @@ import torch
import torch.nn.functional as F
from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
......@@ -54,7 +54,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
......@@ -114,8 +114,8 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
if is_transformers_version_greater_than_4_43():
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
if is_transformers_version_greater_than("4.43.0"):
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
......@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
model_type = getattr(config, "model_type", None)
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
_patch_for_block_diag_attn(model_type)
logger.info("Using block diagonal attention for sequence packing without cross-attention.")
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
else:
raise ValueError("Current model does not support block diagonal attention.")
......@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
......@@ -39,7 +39,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
@unique
......@@ -109,7 +109,7 @@ def configure_quantization(
"""
if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None:
logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
......@@ -130,7 +130,7 @@ def configure_quantization(
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
......@@ -149,7 +149,7 @@ def configure_quantization(
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
......@@ -179,7 +179,7 @@ def configure_quantization(
else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
......@@ -191,7 +191,7 @@ def configure_quantization(
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")
......@@ -201,4 +201,4 @@ def configure_quantization(
require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig()
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
......@@ -19,7 +19,7 @@
import math
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
from ...extras import logging
if TYPE_CHECKING:
......@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
......@@ -36,30 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
logger.warning_rank0("Current model does not support RoPE scaling.")
return
if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == "dynamic":
logger.warning(
logger.warning_rank0(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
logger.info(
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
)
logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
logger.info_rank0(
f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
)
......@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger
from ...extras import logging
from ...extras.misc import get_current_device
......@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _get_unsloth_kwargs(
......@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model = None
model_args.use_unsloth = False
......
......@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
import torch
from transformers.utils import cached_file
from ...extras import logging
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger
if TYPE_CHECKING:
......@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
......@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
except Exception as err:
err_text = str(err)
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
return None
......
......@@ -18,21 +18,21 @@
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
import torch
import transformers
import transformers.models
from transformers.activations import ACT2FN
from transformers.utils import logging
from ...extras.logging import get_logger
from ...extras import logging
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
transformers_logger = logging.get_logger(__name__)
logger = logging.get_logger(__name__)
transformers_logger = transformers.utils.logging.get_logger(__name__)
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
......@@ -92,14 +92,14 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None)
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif model_type == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
else:
return
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
......@@ -113,12 +113,13 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
"llava_next",
"llava_next_video",
"paligemma",
"pixtral",
"video_llava",
]: # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
......@@ -128,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
"""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
......@@ -162,19 +163,21 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
return image_seqlen
def get_patch_size(config: "PretrainedConfig") -> int:
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Computes the patch size of the vit.
"""
patch_size = getattr(config.vision_config, "patch_size", -1)
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Get the vision_feature_select_strategy.
"""
vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
)
return vision_feature_select_strategy
......@@ -186,8 +189,10 @@ def patch_target_modules(
"""
model_type = getattr(config, "model_type", None)
if finetuning_args.freeze_vision_tower:
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif model_type == "mllama":
return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))
elif model_type == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
else:
......@@ -195,5 +200,7 @@ def patch_target_modules(
else:
if model_type == "qwen2_vl":
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
elif model_type == "pixtral":
return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules
......@@ -22,7 +22,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from ..extras import logging
from ..extras.misc import infer_optim_dtype
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
......@@ -49,7 +49,7 @@ if TYPE_CHECKING:
from ..hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
......@@ -66,11 +66,11 @@ def patch_processor(
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "patch_size", get_patch_size(config))
setattr(processor, "patch_size", get_patch_size(config, processor))
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
def patch_config(
......@@ -100,7 +100,7 @@ def patch_config(
if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.")
logger.info_rank0("Using KV cache for faster generation.")
if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
......@@ -165,7 +165,7 @@ def patch_model(
try:
model.add_model_tags(["llama-factory"])
except Exception:
logger.warning("Cannot properly tag the model.")
logger.warning_rank0("Cannot properly tag the model.")
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import json
import logging
import os
import signal
import sys
......@@ -34,8 +33,8 @@ from transformers.utils import (
)
from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger
from ..extras.misc import get_peak_memory
......@@ -48,7 +47,7 @@ if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def fix_valuehead_checkpoint(
......@@ -92,7 +91,7 @@ def fix_valuehead_checkpoint(
else:
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
logger.info("Value head model saved at: {}".format(output_dir))
logger.info_rank0(f"Value head model saved at: {output_dir}")
class FixValueHeadModelCallback(TrainerCallback):
......@@ -106,7 +105,7 @@ class FixValueHeadModelCallback(TrainerCallback):
Event called after a checkpoint save.
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
fix_valuehead_checkpoint(
model=kwargs.pop("model"), output_dir=output_dir, safe_serialization=args.save_safetensors
)
......@@ -123,13 +122,13 @@ class SaveProcessorCallback(TrainerCallback):
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(self.processor, "image_processor").save_pretrained(output_dir)
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
self.processor.save_pretrained(output_dir)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save:
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
self.processor.save_pretrained(args.output_dir)
class PissaConvertCallback(TrainerCallback):
......@@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback):
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
if isinstance(model, PeftModel):
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
setattr(model.peft_config["default"], "init_lora_weights", True)
......@@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback):
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
# 1. save a pissa backup with init_lora_weights: True
# 2. save a converted lora with init_lora_weights: pissa
# 3. load the pissa backup with init_lora_weights: True
......@@ -200,8 +199,8 @@ class LogCallback(TrainerCallback):
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
logging.root.addHandler(self.logger_handler)
self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
logging.add_handler(self.logger_handler)
transformers.logging.add_handler(self.logger_handler)
def _set_abort(self, signum, frame) -> None:
......@@ -243,7 +242,7 @@ class LogCallback(TrainerCallback):
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning("Previous trainer log in this folder will be deleted.")
logger.warning_once("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@override
......@@ -288,13 +287,13 @@ class LogCallback(TrainerCallback):
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
loss=state.log_history[-1].get("loss", None),
eval_loss=state.log_history[-1].get("eval_loss", None),
predict_loss=state.log_history[-1].get("predict_loss", None),
reward=state.log_history[-1].get("reward", None),
accuracy=state.log_history[-1].get("rewards/accuracies", None),
learning_rate=state.log_history[-1].get("learning_rate", None),
epoch=state.log_history[-1].get("epoch", None),
loss=state.log_history[-1].get("loss"),
eval_loss=state.log_history[-1].get("eval_loss"),
predict_loss=state.log_history[-1].get("predict_loss"),
reward=state.log_history[-1].get("reward"),
accuracy=state.log_history[-1].get("rewards/accuracies"),
lr=state.log_history[-1].get("learning_rate"),
epoch=state.log_history[-1].get("epoch"),
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
......@@ -305,16 +304,17 @@ class LogCallback(TrainerCallback):
if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
vram_allocated, vram_reserved = get_peak_memory()
logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)
logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
)
)
if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
for extra_key in ("reward", "accuracy", "throughput"):
if logs.get(extra_key):
log_str += f", '{extra_key}': {logs[extra_key]:.2f}"
logger.info_rank0("{" + log_str + "}")
if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)
......
......@@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
......@@ -100,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
self.callback_handler.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
......@@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
r"""
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
......@@ -156,7 +164,7 @@ class CustomDPOTrainer(DPOTrainer):
elif self.loss_type == "simpo":
losses = self.simpo_loss(policy_chosen_logps, policy_rejected_logps)
else:
raise NotImplementedError("Unknown loss type: {}.".format(self.loss_type))
raise NotImplementedError(f"Unknown loss type: {self.loss_type}.")
chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()
......@@ -242,19 +250,59 @@ class CustomDPOTrainer(DPOTrainer):
if self.ftx_gamma > 1e-6:
losses += self.ftx_gamma * sft_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics["{}rewards/chosen".format(prefix)] = chosen_rewards.mean().cpu()
metrics["{}rewards/rejected".format(prefix)] = rejected_rewards.mean().cpu()
metrics["{}rewards/accuracies".format(prefix)] = reward_accuracies.mean().cpu()
metrics["{}rewards/margins".format(prefix)] = (chosen_rewards - rejected_rewards).mean().cpu()
metrics["{}logps/rejected".format(prefix)] = policy_rejected_logps.detach().mean().cpu()
metrics["{}logps/chosen".format(prefix)] = policy_chosen_logps.detach().mean().cpu()
metrics["{}logits/rejected".format(prefix)] = policy_rejected_logits.detach().mean().cpu()
metrics["{}logits/chosen".format(prefix)] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
if self.loss_type == "orpo":
metrics["{}sft_loss".format(prefix)] = sft_loss.detach().mean().cpu()
metrics["{}odds_ratio_loss".format(prefix)] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()
return losses.mean(), metrics
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
return loss
@override
def log(self, logs: Dict[str, float]) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
key_list, metric_list = [], []
for key, metrics in self._stored_metrics[train_eval].items():
key_list.append(key)
metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).mean().item())
del self._stored_metrics[train_eval]
if len(metric_list) < 10: # pad to for all reduce
for i in range(10 - len(metric_list)):
key_list.append(f"dummy_{i}")
metric_list.append(0.0)
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
for key, metric in zip(key_list, metric_list): # add remaining items
if not key.startswith("dummy_"):
logs[key] = metric
return Trainer.log(self, logs)
......@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
......@@ -64,6 +65,12 @@ def run_dpo(
# Update arguments
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
effective_token_num = 0.0
if finetuning_args.include_effective_tokens_per_second:
for data in dataset_module["train_dataset"]:
effective_token_num += len(data["chosen_input_ids"])
effective_token_num += len(data["rejected_input_ids"])
# Initialize our Trainer
trainer = CustomDPOTrainer(
model=model,
......@@ -79,6 +86,12 @@ def run_dpo(
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
......
......@@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
......@@ -95,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
......@@ -120,20 +121,27 @@ class CustomKTOTrainer(KTOTrainer):
"""
return Trainer._get_train_sampler(self)
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
Replaces the method of KTO Trainer with the one of the standard Trainer.
"""
return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
@override
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]:
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Runs forward pass and computes the log probabilities.
"""
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
model_inputs = {
"input_ids": batch["{}input_ids".format(prefix)],
"attention_mask": batch["{}attention_mask".format(prefix)],
"input_ids": batch[f"{prefix}input_ids"],
"attention_mask": batch[f"{prefix}attention_mask"],
}
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
if f"{prefix}token_type_ids" in batch:
model_inputs["token_type_ids"] = batch[f"{prefix}token_type_ids"]
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
......@@ -142,24 +150,26 @@ class CustomKTOTrainer(KTOTrainer):
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
return logps, logps / valid_length
logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
return logits, logps, logps / valid_length
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logps, target_logps_avg = self.forward(model, batch)
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logits, target_logps, target_logps_avg = self.forward(model, batch)
with torch.no_grad():
kl_logps, _ = self.forward(model, batch, prefix="kl_")
_, kl_logps, _ = self.forward(model, batch, prefix="kl_")
if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.")
chosen_logits = target_logits[batch["kto_tags"]]
chosen_logps = target_logps[batch["kto_tags"]]
rejected_logits = target_logits[~batch["kto_tags"]]
rejected_logps = target_logps[~batch["kto_tags"]]
chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg
@override
def compute_reference_log_probs(
......@@ -176,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
ref_context = nullcontext()
with torch.no_grad(), ref_context:
reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
ref_model, batch
)
......@@ -192,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {}
policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
self.concatenated_forward(model, batch)
)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_kl_logps,
policy_chosen_logps_avg,
) = self.concatenated_forward(model, batch)
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
model, batch
)
......@@ -212,22 +227,73 @@ class CustomKTOTrainer(KTOTrainer):
sft_loss = -policy_chosen_logps_avg
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
num_chosen = len(chosen_rewards)
num_rejected = len(rejected_rewards)
if num_chosen > 0:
metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
metrics["count/chosen"] = float(num_chosen)
all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
if num_rejected > 0:
metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
metrics["count/rejected"] = float(num_rejected)
if all_num_chosen > 0:
metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
metrics["count/chosen"] = all_num_chosen
metrics["kl"] = kl.item()
return losses, metrics
if all_num_rejected > 0:
metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
metrics["count/rejected"] = all_num_rejected
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
metrics["kl"] = kl.item()
return loss
return losses, metrics
@override
def log(self, logs: Dict[str, float]) -> None:
r"""
Log `logs` on the various objects watching training, including stored metrics.
"""
# logs either has "loss" or "eval_loss"
train_eval = "train" if "loss" in logs else "eval"
prefix = "eval_" if train_eval == "eval" else ""
# Add averaged stored metrics to logs
key_list, metric_list = [], []
for key, metrics in self._stored_metrics[train_eval].items():
key_list.append(key)
metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())
del self._stored_metrics[train_eval]
if len(metric_list) < 9: # pad to for all reduce
for i in range(9 - len(metric_list)):
key_list.append(f"dummy_{i}")
metric_list.append(0.0)
metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
for split in ["chosen", "rejected"]: # accumulate average metrics from sums and lengths
if f"count/{split}" in metric_dict:
for key in ("rewards", "logps", "logits"):
logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
del metric_dict[f"{key}/{split}_sum"]
del metric_dict[f"count/{split}"]
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: # calculate reward margin
logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
for key, metric in metric_dict.items(): # add remaining items
if not key.startswith("dummy_"):
logs[key] = metric
return Trainer.log(self, logs)
......@@ -81,7 +81,7 @@ def run_kto(
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"])
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"])
# Evaluation
if training_args.do_eval:
......
......@@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
device = v_head_layer.weight.device
v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
......
......@@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override
from ...extras.logging import get_logger
from ...extras import logging
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
......@@ -58,7 +58,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class CustomPPOTrainer(PPOTrainer, Trainer):
......@@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
if ppo_config.log_with is not None:
logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
ppo_config.log_with = None
# Create optimizer and scheduler
......@@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")
self.amp_context = torch.autocast(self.current_device.type)
warnings.simplefilter("ignore") # remove gc warnings on ref model
......@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
......@@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.state.is_local_process_zero = self.is_local_process_zero()
self.state.is_world_process_zero = self.is_world_process_zero()
if self.is_world_process_zero():
logger.info("***** Running training *****")
logger.info(" Num examples = {:,}".format(num_examples))
logger.info(" Num Epochs = {:,}".format(num_train_epochs))
logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
logger.info(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
total_train_batch_size
)
logger.info_rank0("***** Running training *****")
logger.info_rank0(f" Num examples = {num_examples:,}")
logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
logger.info_rank0(
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
total_train_batch_size
)
logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
logger.info(" Total training steps = {:,}".format(max_steps))
logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
)
logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
logger.info_rank0(f" Total training steps = {max_steps:,}")
logger.info_rank0(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")
dataiter = iter(self.dataloader)
loss_meter = AverageMeter()
......@@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
self.log_stats(stats, batch, rewards)
except Exception:
logger.warning("Failed to save stats due to unknown errors.")
logger.warning_rank0("Failed to save stats due to unknown errors.")
self.state.global_step += 1
self.callback_handler.on_step_end(self.args, self.state, self.control)
......@@ -290,7 +289,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if (step + 1) % self.args.save_steps == 0: # save checkpoint
self.save_model(
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
)
self.callback_handler.on_save(self.args, self.state, self.control)
......@@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
except ValueError:
logger.warning(
logger.warning_rank0(
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
" use zero_to_fp32.py to recover weights"
)
......
......@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
......@@ -30,9 +30,6 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments
logger = get_logger(__name__)
class CustomTrainer(Trainer):
r"""
Inherits Trainer for custom optimizer.
......@@ -51,7 +48,7 @@ class CustomTrainer(Trainer):
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
......@@ -68,3 +65,19 @@ class CustomTrainer(Trainer):
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
# other model should not scale the loss
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
return loss
......@@ -24,7 +24,8 @@ import torch
from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger
from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
......@@ -36,7 +37,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class PairwiseTrainer(Trainer):
......@@ -59,7 +60,7 @@ class PairwiseTrainer(Trainer):
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
......@@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
......@@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0
if return_outputs:
return loss, (loss, chosen_scores, rejected_scores)
else:
......@@ -113,7 +118,7 @@ class PairwiseTrainer(Trainer):
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
chosen_scores, rejected_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
......
......@@ -25,8 +25,9 @@ import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
......@@ -39,7 +40,7 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
......@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
......@@ -78,6 +79,22 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
"""
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
# other model should not scale the loss
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
return loss
@override
def prediction_step(
self,
......@@ -129,7 +146,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
labels = np.where(
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment