Commit 032b90a1 authored by luopl's avatar luopl
Browse files

init commit

parents
Pipeline #1684 canceled with stages
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
from typing_extensions import Self
if TYPE_CHECKING:
import torch
@dataclass
class ModelArguments:
r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
"""
model_name_or_path: str = field(
metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
)
adapter_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to the adapter weight or identifier from huggingface.co/models. "
"Use commas to separate multiple adapters."
)
},
)
adapter_folder: Optional[str] = field(
default=None,
metadata={"help": "The folder containing the adapter weights to load."},
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
)
resize_vocab: bool = field(
default=False,
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
)
split_special_tokens: bool = field(
default=False,
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
)
new_special_tokens: Optional[str] = field(
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
low_cpu_mem_usage: bool = field(
default=True,
metadata={"help": "Whether or not to use memory-efficient model loading."},
)
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
default="bitsandbytes",
metadata={"help": "Quantization method to use for on-the-fly quantization."},
)
quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
)
quantization_type: Literal["fp4", "nf4"] = field(
default="nf4",
metadata={"help": "Quantization data type to use in int4 training."},
)
double_quantization: bool = field(
default=True,
metadata={"help": "Whether or not to use double quantization in int4 training."},
)
quantization_device_map: Optional[Literal["auto"]] = field(
default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)
shift_attn: bool = field(
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
default=None,
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
)
visual_inputs: bool = field(
default=False,
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
)
moe_aux_loss_coef: Optional[float] = field(
default=None,
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
)
disable_gradient_checkpointing: bool = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
)
upcast_layernorm: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
)
upcast_lmhead_output: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
)
train_from_scratch: bool = field(
default=False,
metadata={"help": "Whether or not to randomly initialize the model weights."},
)
infer_backend: Literal["huggingface", "vllm"] = field(
default="huggingface",
metadata={"help": "Backend engine used at inference."},
)
vllm_maxlen: int = field(
default=2048,
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
)
vllm_gpu_util: float = field(
default=0.9,
metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
)
vllm_enforce_eager: bool = field(
default=False,
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
)
vllm_max_lora_rank: int = field(
default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
)
offload_folder: str = field(
default="offload",
metadata={"help": "Path to offload model weights."},
)
use_cache: bool = field(
default=True,
metadata={"help": "Whether or not to use KV cache in generation."},
)
infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
default="auto",
metadata={"help": "Data type for model weights and activations at inference."},
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."},
)
export_size: int = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
)
export_device: Literal["cpu", "auto"] = field(
default="cpu",
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."},
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
)
export_quantization_nsamples: int = field(
default=128,
metadata={"help": "The number of samples used for quantization."},
)
export_quantization_maxlen: int = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
)
export_legacy_format: bool = field(
default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
)
export_hub_model_id: Optional[str] = field(
default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
print_param_status: bool = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
def __post_init__(self):
self.compute_dtype: Optional["torch.dtype"] = None
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
self.model_max_length: Optional[int] = None
self.block_diag_attn: bool = False
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
if self.visual_inputs and self.use_unsloth:
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
if self.adapter_name_or_path is not None: # support merging multiple lora weights
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
if self.new_special_tokens is not None: # support multiple special tokens
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@classmethod
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
arg_dict = old_arg.to_dict()
arg_dict.update(**kwargs)
new_arg = cls(**arg_dict)
new_arg.compute_dtype = old_arg.compute_dtype
new_arg.device_map = old_arg.device_map
new_arg.model_max_length = old_arg.model_max_length
new_arg.block_diag_attn = old_arg.block_diag_attn
return new_arg
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
from typing import Any, Dict, Optional, Tuple
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.versions import require_version
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
logger = get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
if unknown_args:
print(parser.format_help())
print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
return (*parsed_args,)
def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
def _verify_model_args(
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
) -> None:
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Adapter is only valid for the LoRA method.")
if model_args.quantization_bit is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if finetuning_args.pissa_init:
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.")
if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
raise ValueError("Cannot create new adapter upon a quantized model.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if data_args.template == "yi" and model_args.use_fast_tokenizer:
logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3")
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
if finetuning_args.use_badam:
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
if training_args.should_log:
_set_transformers_logging()
# Check arguments
if finetuning_args.stage != "pt" and data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
if finetuning_args.stage != "sft" and data_args.neat_packing:
raise ValueError("`neat_packing` cannot be set as True except SFT.")
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
raise ValueError("Unsloth does not support lora reward model.")
if (
finetuning_args.stage == "ppo"
and training_args.report_to
and training_args.report_to[0] not in ["wandb", "tensorboard"]
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
if training_args.do_train and data_args.dataset is None:
raise ValueError("Please specify dataset for training.")
if (training_args.do_eval or training_args.do_predict) and (
data_args.eval_dataset is None and data_args.val_size < 1e-6
):
raise ValueError("Please specify dataset for evaluation.")
if training_args.predict_with_generate and data_args.eval_dataset is None:
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
if training_args.predict_with_generate and finetuning_args.compute_accuracy:
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
if training_args.do_train and model_args.quantization_device_map == "auto":
raise ValueError("Cannot use device map for quantized models in training.")
if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available():
raise ValueError("This device does not support `pure_bf16`.")
if is_deepspeed_zero3_enabled():
raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
if (
finetuning_args.use_galore
and finetuning_args.galore_layerwise
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
raise ValueError("Distributed training does not support layer-wise GaLore.")
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
if finetuning_args.badam_mode == "ratio":
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
elif not is_deepspeed_zero3_enabled():
raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
if model_args.visual_inputs and data_args.packing:
raise ValueError("Cannot use packing in MLLM fine-tuning.")
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
if data_args.neat_packing and not data_args.packing:
logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
data_args.packing = True
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args, training_args)
if (
training_args.do_train
and finetuning_args.finetuning_type == "lora"
and model_args.quantization_bit is None
and model_args.resize_vocab
and finetuning_args.additional_target is None
):
logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# Post-process training arguments
if (
training_args.parallel_mode == ParallelMode.DISTRIBUTED
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False
if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
can_resume_from_checkpoint = False
if training_args.resume_from_checkpoint is not None:
logger.warning("Cannot resume from checkpoint in current stage.")
training_args.resume_from_checkpoint = None
else:
can_resume_from_checkpoint = True
if (
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and any(
os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
):
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
if last_checkpoint is not None:
training_args.resume_from_checkpoint = last_checkpoint
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
if (
finetuning_args.stage in ["rm", "ppo"]
and finetuning_args.finetuning_type == "lora"
and training_args.resume_from_checkpoint is not None
):
logger.warning(
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
training_args.resume_from_checkpoint
)
)
# Post-process model arguments
if training_args.bf16 or finetuning_args.pure_bf16:
model_args.compute_dtype = torch.bfloat16
elif training_args.fp16:
model_args.compute_dtype = torch.float16
model_args.device_map = {"": get_current_device()}
model_args.model_max_length = data_args.cutoff_len
model_args.block_diag_attn = data_args.neat_packing
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
# Log on each process the small summary
logger.info(
"Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
training_args.local_rank,
training_args.device,
training_args.n_gpu,
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
str(model_args.compute_dtype),
)
)
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.infer_backend == "vllm":
if finetuning_args.stage != "sft":
raise ValueError("vLLM engine only supports auto-regressive models.")
if model_args.quantization_bit is not None:
raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
if model_args.rope_scaling is not None:
raise ValueError("vLLM engine does not support RoPE scaling.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("vLLM only accepts a single adapter. Merge them first.")
if finetuning_args.stage == "rm" and model_args.visual_inputs:
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
if model_args.export_dir is not None and model_args.export_device == "cpu":
model_args.device_map = {"": torch.device("cpu")}
model_args.model_max_length = data_args.cutoff_len
else:
model_args.device_map = "auto"
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
_verify_model_args(model_args, data_args, finetuning_args)
_check_extra_dependencies(model_args, finetuning_args)
model_args.device_map = "auto"
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from llamafactory.train.tuner import run_exp
def launch():
run_exp()
if __name__ == "__main__":
launch()
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .loader import load_config, load_model, load_tokenizer
from .model_utils.misc import find_all_linear_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.valuehead import load_valuehead_params
__all__ = [
"QuantizationMethod",
"load_config",
"load_model",
"load_tokenizer",
"find_all_linear_modules",
"load_valuehead_params",
]
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import TYPE_CHECKING
import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from ..extras.logging import get_logger
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
def _setup_full_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
if not is_trainable:
return
logger.info("Fine-tuning method: Full")
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
def _setup_freeze_tuning(
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> None:
if not is_trainable:
return
logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs:
config = model.config.text_config
else:
config = model.config
num_layers = (
getattr(config, "num_hidden_layers", None)
or getattr(config, "num_layers", None)
or getattr(config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.use_llama_pro:
if num_layers % finetuning_args.freeze_trainable_layers != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
num_layers, finetuning_args.freeze_trainable_layers
)
)
stride = num_layers // finetuning_args.freeze_trainable_layers
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))
hidden_modules = set()
non_hidden_modules = set()
for name, _ in model.named_parameters():
if ".0." in name:
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # MoD starts from layer 1
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
if re.search(r"\.\d+\.", name) is None:
non_hidden_modules.add(name.split(".")[-2])
trainable_layers = []
for module_name in finetuning_args.freeze_trainable_modules:
if module_name != "all" and module_name not in hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
)
for idx in trainable_layer_ids:
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
if finetuning_args.freeze_extra_modules:
for module_name in finetuning_args.freeze_extra_modules:
if module_name not in non_hidden_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules))
)
trainable_layers.append(module_name)
forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
forbidden_module in name for forbidden_module in forbidden_modules
):
if cast_trainable_params_to_fp32:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
def _setup_lora_tuning(
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
if is_trainable:
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
is_mergeable = True
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
is_mergeable = False
if is_deepspeed_zero3_enabled():
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
is_mergeable = False
if model_args.use_unsloth:
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
is_mergeable = False
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
adapter_to_merge = model_args.adapter_name_or_path[:-1]
adapter_to_resume = model_args.adapter_name_or_path[-1]
else:
adapter_to_merge = model_args.adapter_name_or_path
init_kwargs = {
"subfolder": model_args.adapter_folder,
"offload_folder": model_args.offload_folder,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
for adapter in adapter_to_merge:
model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
target_modules = finetuning_args.lora_target
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
if (
finetuning_args.use_dora
and getattr(model, "quantization_method", None) is not None
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
):
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
if model_args.resize_vocab and finetuning_args.additional_target is None:
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()
module_names = set()
for name, module in model.named_modules():
if module in [input_embeddings, output_embeddings]:
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
"lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora,
"use_dora": finetuning_args.use_dora,
"modules_to_save": finetuning_args.additional_target,
}
if model_args.use_unsloth:
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
else:
if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1:
logger.info("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa"
else:
logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)
if is_trainable and cast_trainable_params_to_fp32:
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
return model
def init_adapter(
config: "PretrainedConfig",
model: "PreTrainedModel",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool,
) -> "PreTrainedModel":
r"""
Initializes the adapters.
Support full-parameter, freeze and LoRA training.
Note that the trainable parameters must be cast to float32.
"""
if is_trainable and getattr(model, "quantization_method", None) is not None:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantized models can only be used for the LoRA tuning.")
if finetuning_args.pissa_init:
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
# cast trainable parameters to float32 if:
# 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
# 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
cast_trainable_params_to_fp32 = False
if not is_trainable:
pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else:
logger.info("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full":
_setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "freeze":
_setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
elif finetuning_args.finetuning_type == "lora":
model = _setup_lora_tuning(
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
)
else:
raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
return model
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
class TokenizerModule(TypedDict):
tokenizer: "PreTrainedTokenizer"
processor: Optional["ProcessorMixin"]
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
r"""
Gets arguments to load config/tokenizer/model.
Note: including inplace operation of model_args.
"""
skip_check_imports()
model_args.model_name_or_path = try_download_model_from_ms(model_args)
return {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
}
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer.
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
except ValueError: # try the fast one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
padding_side="right",
**init_kwargs,
)
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
if model_args.visual_inputs:
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
except Exception:
raise ValueError(
"This multimodal LLM is not supported.\n"
"Download LLaVA-1.5 models from: https://huggingface.co/llava-hf\n"
"Download Yi-VL models from: https://huggingface.co/BUAADreamer"
)
else:
processor = None
return {"tokenizer": tokenizer, "processor": processor}
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
r"""
Loads model config.
"""
init_kwargs = _get_init_kwargs(model_args)
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
def load_model(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
is_trainable: bool = False,
add_valuehead: bool = False,
) -> "PreTrainedModel":
r"""
Loads pretrained model.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
model = None
lazy_load = False
if model_args.use_unsloth:
if model_args.adapter_name_or_path is not None:
lazy_load = True
elif is_trainable:
model = load_unsloth_pretrained_model(config, model_args)
if model is None and not lazy_load:
init_kwargs["config"] = config
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
elif model_args.train_from_scratch:
model = AutoModelForCausalLM.from_config(config)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == "convert":
model = convert_pretrained_model_to_mod(model, config, model_args)
if not lazy_load:
patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
register_autoclass(config, model, tokenizer)
model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
if add_valuehead:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1]
else:
vhead_path = model_args.model_name_or_path
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
if not is_trainable:
model.requires_grad_(False)
for param in model.parameters():
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
param.data = param.data.to(model_args.compute_dtype)
model.eval()
else:
model.train()
trainable_params, all_param = count_parameters(model)
if is_trainable:
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
)
else:
param_stats = "all params: {:,}".format(all_param)
logger.info(param_stats)
if model_args.print_param_status:
for name, param in model.named_parameters():
print(
"name: {}, dtype: {}, device: {}, trainable: {}".format(
name, param.dtype, param.device, param.requires_grad
)
)
return model
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_attn_implementation(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else:
logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa":
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
if model_args.flash_attn == "auto":
return
elif model_args.flash_attn == "disabled":
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":
if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.")
return
requested_attn_implementation = "flash_attention_2"
else:
raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
setattr(config, "attn_implementation", requested_attn_implementation)
else:
setattr(config, "_attn_implementation", requested_attn_implementation)
def print_attn_implementation(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "internlm2": # special case for custom models
attn_implementation = getattr(config, "attn_implementation", None)
else:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.")
else:
logger.info("Using vanilla attention implementation.")
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and PEFT library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import partial
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _gradient_checkpointing_enable(
self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
r"""
Activates gradient checkpointing for the current model.
Modification of the original method to enable gradient checkpointing for block-wise optimizer.
"""
from torch.utils.checkpoint import checkpoint
if not self.supports_gradient_checkpointing:
raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
def custom_gradient_checkpointing_func(func, *args, **kwargs):
module: "torch.nn.Module" = func.__self__
if any(param.requires_grad for param in module.parameters()):
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
return gradient_checkpointing_func(func, *args, **kwargs)
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
def _fp32_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(torch.float32)
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if model_args.upcast_lmhead_output:
output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
logger.info("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""
Resize token embeddings.
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
#
# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
# This code is also inspired by the original LongLoRA implementation.
# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
transformers_logger = logging.get_logger(__name__)
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor":
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
transformers_logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor":
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
attn_output: "torch.Tensor" = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Modified from:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
)
return llama_attention_forward(
self,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
cache_position=cache_position,
**kwargs,
)
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
num_groups = q_len // groupsz
def shift(state: "torch.Tensor") -> "torch.Tensor":
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
state = torch.cat(
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
dim=2,
)
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
if query_states.device.type == "cuda" and causal_mask is not None: # avoid pytorch bug
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
attn_output = torch.cat(
(
attn_output[:, :, : self.num_heads // 2],
attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
),
dim=2,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.shift_attn:
return
logger = get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
logger = get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r"""
Finds all available modules to apply lora or galore.
"""
forbidden_modules = {"lm_head"}
if model.config.model_type == "chatglm":
forbidden_modules.add("output_layer")
elif model.config.model_type == "internlm2":
forbidden_modules.add("output")
elif model.config.model_type in ["llava", "paligemma"]:
forbidden_modules.add("multi_modal_projector")
if freeze_vision_tower:
forbidden_modules.add("vision_tower")
module_names = set()
for name, module in model.named_modules():
if any(forbidden_module in name for forbidden_module in forbidden_modules):
continue
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
r"""
Finds the modules in the expanded blocks to apply lora.
"""
num_layers = getattr(model.config, "num_hidden_layers", None)
if not num_layers:
raise ValueError("Model was not supported.")
if num_layers % num_layer_trainable != 0:
raise ValueError(
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
)
stride = num_layers // num_layer_trainable
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
module_names = []
for name, _ in model.named_modules():
if any(target_module in name for target_module in target_modules) and any(
trainable_layer in name for trainable_layer in trainable_layers
):
module_names.append(name)
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...extras.constants import MOD_SUPPORTED_MODELS
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
from MoD import AutoMoDModelForCausalLM
return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)
def convert_pretrained_model_to_mod(
model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
from MoD import apply_mod_to_hf
if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
raise ValueError("Current model is not supported by mixture-of-depth.")
model = apply_mod_to_hf(model)
model = model.to(model_args.compute_dtype)
return model
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Sequence
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
set_z3_leaf_modules(model, leaf_modules)
def add_z3_leaf_module(model: "PreTrainedModel") -> None:
r"""
Sets module as a leaf module to skip partitioning in deepspeed zero3.
"""
if not is_deepspeed_zero3_enabled():
return
if getattr(model.config, "model_type", None) == "dbrx":
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
_set_z3_leaf_modules(model, [DbrxFFN])
if getattr(model.config, "model_type", None) == "jamba":
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
_set_z3_leaf_modules(model, [JambaSparseMoeBlock])
if getattr(model.config, "model_type", None) == "jetmoe":
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
if getattr(model.config, "model_type", None) == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
_set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
if getattr(model.config, "model_type", None) == "qwen2moe":
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
_set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.moe_aux_loss_coef is not None:
if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "deepseek":
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
elif getattr(config, "model_type", None) == "jetmoe":
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
setattr(config, "output_router_logits", is_trainable)
# Copyright 2024 Musab Gultekin and the LlamaFactory team.
#
# This code is based on the Musab Gultekin's functionary library.
# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2023 Musab Gultekin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
import transformers.models
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
r"""
Gets the sequnce lengths in the current batch.
e.g.
```python
# input
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
# output
[2, 3, 1, 2, 3]
```
"""
bsz = attention_mask.size(0)
dtype, device = attention_mask.dtype, attention_mask.device
max_num = torch.max(attention_mask).item()
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
for i in range(max_num):
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
counts = counts.flatten()
seqlens = counts[counts.nonzero().squeeze(dim=-1)]
return seqlens
def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
r"""
Prepares the indices and seqlens for flash attn varlen function.
Returns:
indices: indices of non-masked tokens from the flattened sequence.
cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0.
max_seqlen_in_batch: the largest seqlen in the current batch.
e.g.
```python
# input
[
[1, 1, 2, 2, 2, 0],
[1, 2, 2, 3, 3, 3],
]
# output
[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
[0, 2, 5, 6, 8, 11]
3
```
"""
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return indices, cu_seqlens, max_seqlen_in_batch
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
if model_type == "cohere":
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
elif model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
elif model_type == "mistral":
transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
elif model_type == "phi3":
transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return
model_type = getattr(config, "model_type", None)
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
_patch_for_block_diag_attn(model_type)
logger.info("Using block diagonal attention for sequence packing without cross-attention.")
else:
raise ValueError("Current model does not support block diagonal attention.")
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers and Optimum library.
# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from enum import Enum, unique
from typing import TYPE_CHECKING, Any, Dict, List
import torch
from datasets import load_dataset
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
from ...extras.constants import FILEEXT2TYPE
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from ...hparams import ModelArguments
logger = get_logger(__name__)
@unique
class QuantizationMethod(str, Enum):
r"""
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
"""
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
r"""
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
"""
if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = model_args.export_quantization_dataset
else:
data_path = model_args.export_quantization_dataset
data_files = None
dataset = load_dataset(
path=data_path,
data_files=data_files,
split="train",
cache_dir=model_args.cache_dir,
token=model_args.hf_hub_token,
)
samples = []
maxlen = model_args.export_quantization_maxlen
for _ in range(model_args.export_quantization_nsamples):
n_try = 0
while True:
if n_try > 100:
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
n_try += 1
if sample["input_ids"].size(1) > maxlen:
break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})
return samples
def configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
init_kwargs: Dict[str, Any],
) -> None:
r"""
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
"""
if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None:
logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
if quant_method == QuantizationMethod.AQLM:
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
elif model_args.export_quantization_bit is not None: # auto-gptq
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported yet.")
init_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
dataset=_get_quantization_dataset(tokenizer, model_args),
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type,
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
)
else:
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
# Do not assign device map if:
# 1. deepspeed zero3 or fsdp (train)
# 2. auto quantization device map (inference)
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
require_version("hqq", "To fix: pip install hqq")
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig()
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
# Copyright 2024 LMSYS and the LlamaFactory team.
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# This code is inspired by the LMSYS's FastChat library.
# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PretrainedConfig
from ...hparams import ModelArguments
logger = get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if model_args.rope_scaling is None:
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
return
if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == "dynamic":
logger.warning(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
logger.info(
"Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
)
setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(
"Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor)
)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Optional
from ...extras.logging import get_logger
from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments"
) -> Dict[str, Any]:
return {
"model_name": model_name_or_path,
"max_seq_length": model_args.model_max_length or 4096,
"dtype": model_args.compute_dtype,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
"use_gradient_checkpointing": "unsloth",
}
def load_unsloth_pretrained_model(
config: "PretrainedConfig", model_args: "ModelArguments"
) -> Optional["PreTrainedModel"]:
r"""
Optionally loads pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args)
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model = None
model_args.use_unsloth = False
return model
def get_unsloth_peft_model(
model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any]
) -> "PreTrainedModel":
r"""
Gets the peft model for the pretrained model with unsloth. Used in training.
"""
from unsloth import FastLanguageModel
unsloth_peft_kwargs = {
"model": model,
"max_seq_length": model_args.model_max_length,
"use_gradient_checkpointing": "unsloth",
}
return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
def load_unsloth_peft_model(
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> "PreTrainedModel":
r"""
Loads peft model with unsloth. Used in both training and inference.
"""
from unsloth import FastLanguageModel
unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args)
try:
if not is_trainable:
unsloth_kwargs["use_gradient_checkpointing"] = False
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
if not is_trainable:
FastLanguageModel.for_inference(model)
return model
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
import torch
from transformers.utils import cached_file
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
err_text = ""
try:
from safetensors import safe_open
vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
with safe_open(vhead_file, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}
except Exception as err:
err_text = str(err)
try:
vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
return torch.load(vhead_file, map_location="cpu")
except Exception as err:
err_text = str(err)
logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text))
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
return None
def prepare_valuehead_model(model: "PreTrainedModel") -> None:
if getattr(model.config, "model_type", None) == "llava":
setattr(model, "lm_head", model.language_model.get_output_embeddings())
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if getattr(model.config, "model_type", None) == "chatglm":
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if getattr(model.config, "model_type", None) == "internlm2":
setattr(model, "lm_head", model.output)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Tuple
import torch
import transformers.models
from transformers.activations import ACT2FN
from transformers.utils import logging
from ...extras.logging import get_logger
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments
logger = get_logger(__name__)
transformers_logger = logging.get_logger(__name__)
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
def __init__(self, config: "LlavaConfig") -> None:
super().__init__()
self.config = config
if config is None:
return
self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True)
self.act = ACT2FN[config.projector_hidden_act]
def forward(self, image_features: "torch.Tensor") -> "torch.Tensor":
hidden_states = self.linear_1(image_features)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_3(hidden_states)
hidden_states = self.linear_4(hidden_states)
if hidden_states.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.linear_1.weight.dtype
transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")
hidden_states = hidden_states.to(target_dtype)
return hidden_states
class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):
def __init__(self, vision_hidden_size: int, text_hidden_size: int, projector_hidden_act: str) -> None:
super().__init__(config=None)
self.linear_1 = torch.nn.Linear(vision_hidden_size, text_hidden_size, bias=True)
self.linear_2 = torch.nn.LayerNorm(text_hidden_size, bias=True)
self.linear_3 = torch.nn.Linear(text_hidden_size, text_hidden_size, bias=True)
self.linear_4 = torch.nn.LayerNorm(text_hidden_size, bias=True)
self.act = ACT2FN[projector_hidden_act]
def autocast_projector_dtype(
model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
) -> None:
def _mm_projector_forward_post_hook(
module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
) -> "torch.Tensor":
return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
def configure_visual_model(config: "PretrainedConfig") -> None:
if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment