Commit 8293100a authored by luopl's avatar luopl
Browse files

update to 0.9.2.dev0

parent 2778a3d0
......@@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .parser import get_eval_args, get_infer_args, get_train_args
from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
from .training_args import RayArguments, TrainingArguments
__all__ = [
......@@ -26,7 +27,11 @@ __all__ = [
"FinetuningArguments",
"GeneratingArguments",
"ModelArguments",
"RayArguments",
"TrainingArguments",
"get_eval_args",
"get_infer_args",
"get_ray_args",
"get_train_args",
"read_args",
]
......@@ -15,8 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Literal, Optional
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
@dataclass
......@@ -99,7 +99,7 @@ class DataArguments:
)
val_size: float = field(
default=0.0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
)
packing: Optional[bool] = field(
default=None,
......@@ -161,3 +161,6 @@ class DataArguments:
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import List, Literal, Optional
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional
@dataclass
......@@ -251,6 +251,59 @@ class GaloreArguments:
)
@dataclass
class ApolloArguments:
r"""
Arguments pertaining to the APOLLO algorithm.
"""
use_apollo: bool = field(
default=False,
metadata={"help": "Whether or not to use the APOLLO optimizer."},
)
apollo_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
apollo_rank: int = field(
default=16,
metadata={"help": "The rank of APOLLO gradients."},
)
apollo_update_interval: int = field(
default=200,
metadata={"help": "Number of steps to update the APOLLO projection."},
)
apollo_scale: float = field(
default=1.0,
metadata={"help": "APOLLO scaling coefficient."},
)
apollo_proj: Literal["svd", "random"] = field(
default="random",
metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
)
apollo_proj_type: Literal["std", "right", "left"] = field(
default="std",
metadata={"help": "Type of APOLLO projection."},
)
apollo_scale_type: Literal["channel", "tensor"] = field(
default="channel",
metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
)
apollo_layerwise: bool = field(
default=False,
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
)
apollo_scale_front: bool = field(
default=False,
metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
)
@dataclass
class BAdamArgument:
r"""
......@@ -305,7 +358,37 @@ class BAdamArgument:
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
class SwanLabArguments:
use_swanlab: bool = field(
default=False,
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
)
swanlab_project: str = field(
default="llamafactory",
metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: str = field(
default=None,
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_run_name: str = field(
default=None,
metadata={"help": "The experiment name in SwanLab."},
)
swanlab_mode: Literal["cloud", "local"] = field(
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: str = field(
default=None,
metadata={"help": "The API key for SwanLab."},
)
@dataclass
class FinetuningArguments(
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
......@@ -334,6 +417,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=True,
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
)
freeze_multi_modal_projector: bool = field(
default=True,
metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
)
train_mm_proj_only: bool = field(
default=False,
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
......@@ -342,6 +429,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
)
disable_shuffling: bool = field(
default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
......@@ -363,7 +454,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
......@@ -382,11 +475,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
if self.finetuning_type == "lora" and (self.use_galore or self.use_apollo or self.use_badam):
raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")
if self.use_galore and self.use_badam:
raise ValueError("Cannot use GaLore with BAdam together.")
if int(self.use_galore) + int(self.use_apollo) + (self.use_badam) > 1:
raise ValueError("Cannot use GaLore, APOLLO or BAdam together.")
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.")
......@@ -406,3 +499,8 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args
......@@ -15,6 +15,8 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
from transformers import GenerationConfig
@dataclass
class GeneratingArguments:
......@@ -64,11 +66,22 @@ class GeneratingArguments:
default=None,
metadata={"help": "Default system message to use in chat completion."},
)
skip_special_tokens: bool = field(
default=True,
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)
else:
args.pop("max_new_tokens", None)
if obey_generation_config:
generation_config = GenerationConfig()
for key in list(args.keys()):
if not hasattr(generation_config, key):
args.pop(key)
return args
......@@ -16,7 +16,7 @@
# limitations under the License.
import json
from dataclasses import dataclass, field, fields
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
import torch
......@@ -237,6 +237,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
)
use_reentrant_gc: bool = field(
default=True,
metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
)
upcast_layernorm: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
......@@ -281,6 +285,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
init=False,
......@@ -336,3 +344,8 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
setattr(result, name, value)
return result
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args
......@@ -15,56 +15,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
from typing import Any, Dict, Optional, Tuple
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
import yaml
from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.misc import check_dependencies, get_current_device
from ..extras.misc import check_dependencies, check_version, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .training_args import RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
if args is not None:
return parser.parse_dict(args)
return args
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return json.loads(Path(sys.argv[1]).absolute().read_text())
else:
return sys.argv[1:]
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
args = read_args(args)
if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
if unknown_args:
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
if unknown_args and not allow_extra_keys:
print(parser.format_help())
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
......@@ -110,54 +121,64 @@ def _verify_model_args(
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
training_args: Optional["TrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
check_version("unsloth", mandatory=True)
if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel")
check_version("liger-kernel", mandatory=True)
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
check_version("vllm>=0.4.3,<=0.6.5")
check_version("vllm", mandatory=True)
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
check_version("galore_torch", mandatory=True)
if finetuning_args.use_apollo:
check_version("apollo_torch", mandatory=True)
if finetuning_args.use_badam:
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
check_version("badam>=1.2.1", mandatory=True)
if finetuning_args.use_adam_mini:
require_version("adam-mini", "To fix: pip install adam-mini")
check_version("adam-mini", mandatory=True)
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
check_version("matplotlib", mandatory=True)
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
check_version("jieba", mandatory=True)
check_version("nltk", mandatory=True)
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
......@@ -237,21 +258,21 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if is_deepspeed_zero3_enabled():
raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
if (
finetuning_args.use_galore
and finetuning_args.galore_layerwise
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
if finetuning_args.use_galore and finetuning_args.galore_layerwise:
raise ValueError("Distributed training does not support layer-wise GaLore.")
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
raise ValueError("Distributed training does not support layer-wise APOLLO.")
if finetuning_args.use_badam:
if finetuning_args.badam_mode == "ratio":
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
elif not is_deepspeed_zero3_enabled():
raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
......@@ -283,9 +304,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning_rank0("We recommend enable mixed precision training.")
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
if (
training_args.do_train
and (finetuning_args.use_galore or finetuning_args.use_apollo)
and not finetuning_args.pure_bf16
):
logger.warning_rank0(
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
"Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
)
if (not training_args.do_train) and model_args.quantization_bit is not None:
......@@ -361,13 +386,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
str(model_args.compute_dtype),
)
)
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
......@@ -400,7 +424,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
......
import json
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
from ..extras.misc import use_ray
@dataclass
class RayArguments:
r"""
Arguments pertaining to the Ray training.
"""
ray_run_name: Optional[str] = field(
default=None,
metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
)
ray_num_workers: int = field(
default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
)
resources_per_worker: Union[dict, str] = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
)
def __post_init__(self):
self.use_ray = use_ray()
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r"""
Arguments pertaining to the trainer.
"""
def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)
RayArguments.__post_init__(self)
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
import torch
......@@ -52,7 +53,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
skip_check_imports()
model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
return {
"trust_remote_code": True,
"trust_remote_code": model_args.trust_remote_code,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
......@@ -85,6 +86,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
except Exception as e:
raise OSError("Failed to load tokenizer.") from e
if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
tokenizer.model_max_length = model_args.model_max_length
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
......@@ -155,7 +159,7 @@ def load_model(
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config, trust_remote_code=True)
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else:
model = load_class.from_pretrained(**init_kwargs)
......@@ -202,12 +206,8 @@ def load_model(
logger.info_rank0(param_stats)
if model_args.print_param_status:
if model_args.print_param_status and int(os.getenv("LOCAL_RANK", "0")) == 0:
for name, param in model.named_parameters():
print(
"name: {}, dtype: {}, device: {}, trainable: {}".format(
name, param.dtype, param.device, param.requires_grad
)
)
print(f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}")
return model
......@@ -15,9 +15,9 @@
from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.misc import check_version
if TYPE_CHECKING:
......@@ -35,8 +35,8 @@ def configure_attn_implementation(
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
check_version("transformers>=4.42.4")
check_version("flash_attn>=2.6.3")
if model_args.flash_attn != "fa2":
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
......
......@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
......@@ -156,7 +156,9 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
_gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
)
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": model_args.use_reentrant_gc}
)
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info_rank0("Gradient checkpointing enabled.")
......
......@@ -23,21 +23,18 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
import transformers
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils.versions import require_version
from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
if not is_transformers_version_greater_than("4.48.0"):
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
if TYPE_CHECKING:
from transformers import PretrainedConfig
......@@ -353,7 +350,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
check_version("transformers>=4.41.2,<=4.46.1")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
......
......@@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, List
from ...extras import logging
from .visual import COMPOSITE_MODELS
if TYPE_CHECKING:
......@@ -26,7 +27,7 @@ logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r"""
Finds all available modules to apply lora or galore.
Finds all available modules to apply LoRA, GaLore or APOLLO.
"""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}
......@@ -34,18 +35,12 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output_layer")
elif model_type == "internlm2":
forbidden_modules.add("output")
elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
forbidden_modules.add("multi_modal_projector")
elif model_type == "qwen2_vl":
forbidden_modules.add("merger")
if freeze_vision_tower:
if model_type == "mllama":
forbidden_modules.add("vision_model")
elif model_type == "qwen2_vl":
forbidden_modules.add("visual")
else:
forbidden_modules.add("vision_tower")
if model_type in COMPOSITE_MODELS:
forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
if freeze_vision_tower and model_type in COMPOSITE_MODELS:
forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)
module_names = set()
for name, module in model.named_modules():
......
......@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from ...extras.misc import check_version
if TYPE_CHECKING:
......@@ -26,7 +27,7 @@ if TYPE_CHECKING:
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
check_version("deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
set_z3_leaf_modules(model, leaf_modules)
......
......@@ -41,16 +41,17 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
from transformers import PretrainedConfig
if is_transformers_version_greater_than("4.43.0"):
import transformers.modeling_flash_attention_utils
if TYPE_CHECKING:
from ...hparams import ModelArguments
......@@ -113,45 +114,10 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
return indices, cu_seqlens, max_seqlen_in_batch
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
if is_transformers_version_greater_than("4.43.0"):
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
return
import transformers.models
if model_type == "cohere":
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
elif model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
elif model_type == "mistral":
transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
elif model_type == "phi3":
transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return
model_type = getattr(config, "model_type", None)
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
_patch_for_block_diag_attn(model_type)
check_version("transformers>=4.43.0,<=4.46.1")
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
else:
raise ValueError("Current model does not support block diagonal attention.")
......@@ -26,11 +26,10 @@ from datasets import load_dataset
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import FILEEXT2TYPE
from ...extras.misc import get_current_device
from ...extras.misc import check_version, get_current_device
if TYPE_CHECKING:
......@@ -118,15 +117,15 @@ def configure_quantization(
quant_method = quantization_config.get("quant_method", "")
if quant_method == QuantizationMethod.GPTQ:
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
check_version("auto_gptq>=0.5.0", mandatory=True)
quantization_config.pop("disable_exllama", None) # remove deprecated args
quantization_config["use_exllama"] = False # disable exllama
if quant_method == QuantizationMethod.AWQ:
require_version("autoawq", "To fix: pip install autoawq")
check_version("autoawq", mandatory=True)
if quant_method == QuantizationMethod.AQLM:
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
check_version("aqlm>=1.1.0", mandatory=True)
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
......@@ -136,8 +135,8 @@ def configure_quantization(
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
check_version("optimum>=1.17.0", mandatory=True)
check_version("auto_gptq>=0.5.0", mandatory=True)
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
......@@ -154,10 +153,10 @@ def configure_quantization(
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
check_version("bitsandbytes>=0.37.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
check_version("bitsandbytes>=0.39.0", mandatory=True)
init_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
......@@ -175,7 +174,7 @@ def configure_quantization(
if model_args.quantization_bit != 4:
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
check_version("bitsandbytes>=0.43.0", mandatory=True)
else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
......@@ -187,7 +186,7 @@ def configure_quantization(
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
require_version("hqq", "To fix: pip install hqq")
check_version("hqq", mandatory=True)
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
......@@ -199,6 +198,6 @@ def configure_quantization(
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
require_version("eetq", "To fix: pip install eetq")
check_version("eetq", mandatory=True)
init_kwargs["quantization_config"] = EetqConfig()
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
......@@ -39,7 +39,7 @@ def _get_unsloth_kwargs(
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
"trust_remote_code": model_args.trust_remote_code,
"use_gradient_checkpointing": "unsloth",
}
......
......@@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union
import torch
import transformers
......@@ -35,6 +36,40 @@ logger = logging.get_logger(__name__)
transformers_logger = transformers.utils.logging.get_logger(__name__)
@dataclass
class CompositeModel:
model_type: str
projector_key: str
vision_model_keys: List[str]
language_model_keys: List[str]
def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
for key in self.projector_key.split("."):
module = getattr(module, key)
return module
COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
def _register_composite_model(
model_type: str,
projector_key: Optional[str] = None,
vision_model_keys: Optional[List[str]] = None,
language_model_keys: Optional[List[str]] = None,
):
projector_key = projector_key or "multi_modal_projector"
vision_model_keys = vision_model_keys or ["vision_tower"]
language_model_keys = language_model_keys or ["language_model"]
COMPOSITE_MODELS[model_type] = CompositeModel(
model_type=model_type,
projector_key=projector_key,
vision_model_keys=vision_model_keys,
language_model_keys=language_model_keys,
)
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
def __init__(self, config: "LlavaConfig") -> None:
super().__init__()
......@@ -92,10 +127,8 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None)
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
elif model_type == "qwen2_vl":
mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
if model_type in COMPOSITE_MODELS:
mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
else:
return
......@@ -107,15 +140,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
r"""
Patches VLMs before loading them.
"""
model_type = getattr(config, "model_type", None)
if model_type in [
"llava",
"llava_next",
"llava_next_video",
"paligemma",
"pixtral",
"video_llava",
]: # required for ds zero3 and valuehead models
if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
# required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None):
......@@ -129,19 +155,21 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
"""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
if model_type in COMPOSITE_MODELS:
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")
if finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")
vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys
logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.")
forbidden_modules.update(vision_model_keys)
elif model_type == "qwen2_vl":
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("visual")
if finetuning_args.freeze_multi_modal_projector:
projector_key = COMPOSITE_MODELS[model_type].projector_key
logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
forbidden_modules.add(projector_key)
if finetuning_args.train_mm_proj_only:
raise ValueError("Qwen2-VL models do not support `train_mm_proj_only`.")
language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
logger.info_rank0(f"Set language model not trainable: {language_model_keys}.")
forbidden_modules.update(language_model_keys)
return forbidden_modules
......@@ -188,19 +216,73 @@ def patch_target_modules(
Freezes vision tower for VLM LoRA tuning.
"""
model_type = getattr(config, "model_type", None)
vit_model_type = getattr(getattr(config, "vision_config", None), "model_type", None)
if finetuning_args.freeze_vision_tower:
if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif model_type == "mllama":
return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))
elif model_type == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
if model_type in COMPOSITE_MODELS:
vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys
logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.")
vision_model_keys = "|".join(vision_model_keys)
target_modules = "|".join(target_modules)
return f"^(?!.*{vision_model_keys}).*(?:{target_modules}).*"
else:
return target_modules
else:
if model_type == "qwen2_vl":
if model_type == "qwen2_vl": # avoid attaching lora to Conv3D layer
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
elif model_type == "pixtral":
elif vit_model_type == "pixtral":
return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules
_register_composite_model(
model_type="llava",
)
_register_composite_model(
model_type="llava_next",
)
_register_composite_model(
model_type="llava_next_video",
)
_register_composite_model(
model_type="minicpmv",
vision_model_keys=["vpm"],
language_model_keys=["llm"],
)
_register_composite_model(
model_type="minicpmo",
vision_model_keys=["vpm", "apm", "resampler", "tts"],
language_model_keys=["llm"],
)
_register_composite_model(
model_type="paligemma",
)
_register_composite_model(
model_type="video_llava",
)
_register_composite_model(
model_type="mllama",
vision_model_keys=["vision_model"],
)
_register_composite_model(
model_type="qwen2_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["model", "lm_head"],
)
......@@ -24,6 +24,7 @@ from transformers.modeling_utils import is_fsdp_enabled
from ..extras import logging
from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_transformers_version_greater_than
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
from .model_utils.embedding import resize_embedding_layer
......@@ -96,7 +97,7 @@ def patch_config(
configure_quantization(config, tokenizer, model_args, init_kwargs)
configure_moe(config, model_args, is_trainable)
configure_visual_model(config)
configure_packing(config, model_args, is_trainable)
configure_packing(model_args, is_trainable)
if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
......@@ -110,9 +111,16 @@ def patch_config(
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
if getattr(config, "model_type", None) == "minicpmo":
setattr(config, "init_audio", False)
setattr(config, "init_tts", False)
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")
# deepspeed zero3 is not compatible with low_cpu_mem_usage
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
......@@ -145,7 +153,9 @@ def patch_model(
):
gen_config.do_sample = True
if "GenerationMixin" not in str(model.generate.__func__):
if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
model.generate.__func__
):
model.generate = MethodType(PreTrainedModel.generate, model)
if add_valuehead:
......
......@@ -35,17 +35,20 @@ from typing_extensions import override
from ..extras import logging
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import get_peak_memory
from ..extras.misc import get_peak_memory, use_ray
if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import save_file
if TYPE_CHECKING:
from transformers import TrainerControl, TrainerState, TrainingArguments
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = logging.get_logger(__name__)
......@@ -101,9 +104,6 @@ class FixValueHeadModelCallback(TrainerCallback):
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
fix_valuehead_checkpoint(
......@@ -138,9 +138,6 @@ class PissaConvertCallback(TrainerCallback):
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
if args.should_save:
model = kwargs.pop("model")
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
......@@ -197,7 +194,7 @@ class LogCallback(TrainerCallback):
self.do_train = False
# Web UI
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode:
if self.webui_mode and not use_ray():
signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
logging.add_handler(self.logger_handler)
......@@ -242,7 +239,7 @@ class LogCallback(TrainerCallback):
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning_once("Previous trainer log in this folder will be deleted.")
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@override
......@@ -348,3 +345,51 @@ class LogCallback(TrainerCallback):
remaining_time=self.remaining_time,
)
self.thread_pool.submit(self._write_log, args.output_dir, logs)
class ReporterCallback(TrainerCallback):
r"""
A callback for reporting training status to external logger.
"""
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
) -> None:
self.model_args = model_args
self.data_args = data_args
self.finetuning_args = finetuning_args
self.generating_args = generating_args
os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT", "llamafactory")
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if not state.is_world_process_zero:
return
if "wandb" in args.report_to:
import wandb
wandb.config.update(
{
"model_args": self.model_args.to_dict(),
"data_args": self.data_args.to_dict(),
"finetuning_args": self.finetuning_args.to_dict(),
"generating_args": self.generating_args.to_dict(),
}
)
if self.finetuning_args.use_swanlab:
import swanlab # type: ignore
swanlab.config.update(
{
"model_args": self.model_args.to_dict(),
"data_args": self.data_args.to_dict(),
"finetuning_args": self.finetuning_args.to_dict(),
"generating_args": self.generating_args.to_dict(),
}
)
......@@ -19,7 +19,7 @@ import warnings
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
......@@ -29,9 +29,9 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
if TYPE_CHECKING:
......@@ -50,6 +50,9 @@ class CustomDPOTrainer(DPOTrainer):
disable_dropout: bool = True,
**kwargs,
):
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None:
......@@ -79,6 +82,7 @@ class CustomDPOTrainer(DPOTrainer):
self.simpo_gamma = finetuning_args.simpo_gamma
Trainer.__init__(self, model=model, **kwargs)
self.model_accepts_loss_kwargs = False # overwrite trainer's default behavior
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
......@@ -97,9 +101,6 @@ class CustomDPOTrainer(DPOTrainer):
if processor is not None:
self.add_callback(SaveProcessorCallback(processor))
if finetuning_args.pissa_convert:
self.callback_handler.add_callback(PissaConvertCallback)
if finetuning_args.use_badam:
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
......@@ -119,6 +120,13 @@ class CustomDPOTrainer(DPOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def get_batch_samples(self, epoch_iterator, num_batches):
r"""
......@@ -185,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
batch = nested_detach(batch, clone=True) # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
......@@ -266,17 +274,18 @@ class CustomDPOTrainer(DPOTrainer):
return losses.mean(), metrics
@override
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
r"""
Fixes the loss value for transformers 4.46.0.
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
"""
loss = super().compute_loss(model, inputs, return_outputs)
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
if return_outputs:
return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
else:
return loss / self.args.gradient_accumulation_steps
loss = loss / self.args.gradient_accumulation_steps
return loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment