".github/workflows/build-conda-windows.yml" did not exist on "9135b544f343007025db7db007ecd13666de896b"
Commit 8293100a authored by luopl

update to 0.9.2.dev0

parent 2778a3d0
......@@ -17,7 +17,8 @@ from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .parser import get_eval_args, get_infer_args, get_train_args
from .parser import get_eval_args, get_infer_args, get_ray_args, get_train_args, read_args
from .training_args import RayArguments, TrainingArguments
__all__ = [
......@@ -26,7 +27,11 @@ __all__ = [
"FinetuningArguments",
"GeneratingArguments",
"ModelArguments",
"RayArguments",
"TrainingArguments",
"get_eval_args",
"get_infer_args",
"get_ray_args",
"get_train_args",
"read_args",
]
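The `__init__.py` hunk above re-exports `read_args` and `get_ray_args` next to the existing parser helpers. A minimal sketch of how the new exports might be combined (the import path and the Ray gating are assumptions, not part of this commit):

    from llamafactory.hparams import get_ray_args, get_train_args, read_args

    args = read_args()  # a dict for a .yaml/.json config path, otherwise the raw CLI token list
    ray_args = get_ray_args(args)  # parses only RayArguments and tolerates any extra keys
    if not ray_args.use_ray:
        model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)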
......@@ -15,8 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Literal, Optional
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
@dataclass
......@@ -99,7 +99,7 @@ class DataArguments:
)
val_size: float = field(
default=0.0,
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
metadata={"help": "Size of the validation set, should be an integer or a float in range `[0,1)`."},
)
packing: Optional[bool] = field(
default=None,
......@@ -161,3 +161,6 @@ class DataArguments:
if self.mask_history and self.train_on_prompt:
raise ValueError("`mask_history` is incompatible with `train_on_prompt`.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import List, Literal, Optional
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Literal, Optional
@dataclass
......@@ -251,6 +251,59 @@ class GaloreArguments:
)
@dataclass
class ApolloArguments:
r"""
Arguments pertaining to the APOLLO algorithm.
"""
use_apollo: bool = field(
default=False,
metadata={"help": "Whether or not to use the APOLLO optimizer."},
)
apollo_target: str = field(
default="all",
metadata={
"help": (
"Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
apollo_rank: int = field(
default=16,
metadata={"help": "The rank of APOLLO gradients."},
)
apollo_update_interval: int = field(
default=200,
metadata={"help": "Number of steps to update the APOLLO projection."},
)
apollo_scale: float = field(
default=1.0,
metadata={"help": "APOLLO scaling coefficient."},
)
apollo_proj: Literal["svd", "random"] = field(
default="random",
metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
)
apollo_proj_type: Literal["std", "right", "left"] = field(
default="std",
metadata={"help": "Type of APOLLO projection."},
)
apollo_scale_type: Literal["channel", "tensor"] = field(
default="channel",
metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
)
apollo_layerwise: bool = field(
default=False,
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
)
apollo_scale_front: bool = field(
default=False,
metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
)
@dataclass
class BAdamArgument:
r"""
......@@ -305,7 +358,37 @@ class BAdamArgument:
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
class SwanLabArguments:
use_swanlab: bool = field(
default=False,
metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
)
swanlab_project: str = field(
default="llamafactory",
metadata={"help": "The project name in SwanLab."},
)
swanlab_workspace: str = field(
default=None,
metadata={"help": "The workspace name in SwanLab."},
)
swanlab_run_name: str = field(
default=None,
metadata={"help": "The experiment name in SwanLab."},
)
swanlab_mode: Literal["cloud", "local"] = field(
default="cloud",
metadata={"help": "The mode of SwanLab."},
)
swanlab_api_key: str = field(
default=None,
metadata={"help": "The API key for SwanLab."},
)
@dataclass
class FinetuningArguments(
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
):
r"""
Arguments pertaining to which techniques we are going to fine-tune with.
"""
......@@ -334,6 +417,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=True,
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
)
freeze_multi_modal_projector: bool = field(
default=True,
metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
)
train_mm_proj_only: bool = field(
default=False,
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
......@@ -342,6 +429,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
)
disable_shuffling: bool = field(
default=False,
metadata={"help": "Whether or not to disable the shuffling of the training set."},
)
plot_loss: bool = field(
default=False,
metadata={"help": "Whether or not to save the training loss curves."},
......@@ -363,7 +454,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
self.lora_target: List[str] = split_arg(self.lora_target)
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
self.galore_target: List[str] = split_arg(self.galore_target)
self.apollo_target: List[str] = split_arg(self.apollo_target)
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
......@@ -382,11 +475,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
if self.finetuning_type == "lora" and (self.use_galore or self.use_apollo or self.use_badam):
raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")
if self.use_galore and self.use_badam:
raise ValueError("Cannot use GaLore with BAdam together.")
if int(self.use_galore) + int(self.use_apollo) + int(self.use_badam) > 1:
raise ValueError("Cannot use GaLore, APOLLO or BAdam together.")
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
raise ValueError("Cannot use PiSSA for current training stage.")
......@@ -406,3 +499,8 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
if self.pissa_init:
raise ValueError("`pissa_init` is only valid for LoRA training.")
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("api_key") else v for k, v in args.items()}
return args
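To illustrate the new APOLLO and SwanLab fields together with the `to_dict` masking added at the end of this file, a hedged sketch (import path and all values are illustrative, not taken from this commit):

    from llamafactory.hparams import FinetuningArguments

    ft_args = FinetuningArguments(
        finetuning_type="full",       # APOLLO cannot be combined with LoRA, per the check above
        use_apollo=True,
        apollo_target="all",
        apollo_rank=16,
        use_swanlab=True,
        swanlab_api_key="swan-xxxx",  # hypothetical secret
    )
    dumped = ft_args.to_dict()
    # keys ending in "api_key" are replaced by a placeholder such as "<SWANLAB_API_KEY>"
    print(dumped["swanlab_api_key"])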
......@@ -15,6 +15,8 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
from transformers import GenerationConfig
@dataclass
class GeneratingArguments:
......@@ -64,11 +66,22 @@ class GeneratingArguments:
default=None,
metadata={"help": "Default system message to use in chat completion."},
)
skip_special_tokens: bool = field(
default=True,
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)
else:
args.pop("max_new_tokens", None)
if obey_generation_config:
generation_config = GenerationConfig()
for key in list(args.keys()):
if not hasattr(generation_config, key):
args.pop(key)
return args
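A short sketch of the new `skip_special_tokens` field and the `obey_generation_config` switch (import path assumed; behavior inferred from the hunk above):

    from llamafactory.hparams import GeneratingArguments

    gen_args = GeneratingArguments(max_new_tokens=128, skip_special_tokens=True)
    full_dict = gen_args.to_dict()  # keeps every field, drops max_length because max_new_tokens > 0
    hf_dict = gen_args.to_dict(obey_generation_config=True)
    # hf_dict only keeps keys that GenerationConfig() actually has, so e.g. default_system is dropped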
......@@ -16,7 +16,7 @@
# limitations under the License.
import json
from dataclasses import dataclass, field, fields
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
import torch
......@@ -237,6 +237,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
)
use_reentrant_gc: bool = field(
default=True,
metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
)
upcast_layernorm: bool = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
......@@ -281,6 +285,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
init=False,
......@@ -336,3 +344,8 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
setattr(result, name, value)
return result
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)
args = {k: f"<{k.upper()}>" if k.endswith("token") else v for k, v in args.items()}
return args
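A sketch of the two new `ModelArguments` flags and the token masking in `to_dict` (the checkpoint name is illustrative and the import path is assumed):

    from llamafactory.hparams import ModelArguments

    model_args = ModelArguments(
        model_name_or_path="Qwen/Qwen2-VL-7B-Instruct",  # any repo that ships custom modeling code
        trust_remote_code=True,    # remote code is no longer trusted by default
        use_reentrant_gc=False,    # opt out of reentrant gradient checkpointing
    )
    # to_dict() masks keys ending in "token", e.g. hf_hub_token becomes "<HF_HUB_TOKEN>"
    print(model_args.to_dict()["hf_hub_token"])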
......@@ -15,56 +15,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
from typing import Any, Dict, Optional, Tuple
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
import yaml
from transformers import HfArgumentParser
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.misc import check_dependencies, get_current_device
from ..extras.misc import check_dependencies, check_version, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .training_args import RayArguments, TrainingArguments
logger = logging.get_logger(__name__)
check_dependencies()
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_ARGS = [ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, TrainingArguments, FinetuningArguments, GeneratingArguments]
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
def read_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> Union[Dict[str, Any], List[str]]:
if args is not None:
return parser.parse_dict(args)
return args
if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
return yaml.safe_load(Path(sys.argv[1]).absolute().read_text())
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return json.loads(Path(sys.argv[1]).absolute().read_text())
else:
return sys.argv[1:]
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
def _parse_args(
parser: "HfArgumentParser", args: Optional[Union[Dict[str, Any], List[str]]] = None, allow_extra_keys: bool = False
) -> Tuple[Any]:
args = read_args(args)
if isinstance(args, dict):
return parser.parse_dict(args, allow_extra_keys=allow_extra_keys)
if unknown_args:
(*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(args=args, return_remaining_strings=True)
if unknown_args and not allow_extra_keys:
print(parser.format_help())
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
......@@ -110,54 +121,64 @@ def _verify_model_args(
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
training_args: Optional["Seq2SeqTrainingArguments"] = None,
training_args: Optional["TrainingArguments"] = None,
) -> None:
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
check_version("unsloth", mandatory=True)
if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel")
check_version("liger-kernel", mandatory=True)
if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
check_version("mixture-of-depth>=1.1.6", mandatory=True)
if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
check_version("vllm>=0.4.3,<=0.6.5")
check_version("vllm", mandatory=True)
if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
check_version("galore_torch", mandatory=True)
if finetuning_args.use_apollo:
check_version("apollo_torch", mandatory=True)
if finetuning_args.use_badam:
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
check_version("badam>=1.2.1", mandatory=True)
if finetuning_args.use_adam_mini:
require_version("adam-mini", "To fix: pip install adam-mini")
check_version("adam-mini", mandatory=True)
if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
check_version("matplotlib", mandatory=True)
if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
check_version("jieba", mandatory=True)
check_version("nltk", mandatory=True)
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return _parse_args(parser, args)
def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
def _parse_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return _parse_args(parser, args)
def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
def _parse_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return _parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def get_ray_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> RayArguments:
parser = HfArgumentParser(RayArguments)
(ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
return ray_args
def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
# Setup logging
......@@ -237,21 +258,21 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if is_deepspeed_zero3_enabled():
raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
if (
finetuning_args.use_galore
and finetuning_args.galore_layerwise
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
):
if training_args.parallel_mode == ParallelMode.DISTRIBUTED:
if finetuning_args.use_galore and finetuning_args.galore_layerwise:
raise ValueError("Distributed training does not support layer-wise GaLore.")
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
if finetuning_args.use_apollo and finetuning_args.apollo_layerwise:
raise ValueError("Distributed training does not support layer-wise APOLLO.")
if finetuning_args.use_badam:
if finetuning_args.badam_mode == "ratio":
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
elif not is_deepspeed_zero3_enabled():
raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
if finetuning_args.use_galore and training_args.deepspeed is not None:
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
if model_args.infer_backend == "vllm":
raise ValueError("vLLM backend is only available for API, CLI and Web.")
......@@ -283,9 +304,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning_rank0("We recommend enable mixed precision training.")
if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
if (
training_args.do_train
and (finetuning_args.use_galore or finetuning_args.use_apollo)
and not finetuning_args.pure_bf16
):
logger.warning_rank0(
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
"Using GaLore or APOLLO with mixed precision training may significantly increases GPU memory usage."
)
if (not training_args.do_train) and model_args.quantization_bit is not None:
......@@ -361,13 +386,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
str(model_args.compute_dtype),
)
)
transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
_set_transformers_logging()
......@@ -400,7 +424,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
def get_eval_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
_set_transformers_logging()
......
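The parser hunks above replace `require_version(...)` calls with `check_version(..., mandatory=True)` from `extras.misc`. That helper is not part of this diff; a plausible sketch of its contract, inferred from the call sites (the environment-variable name is an assumption):

    import os
    from transformers.utils.versions import require_version

    def check_version(requirement: str, mandatory: bool = False) -> None:
        # optional checks can be skipped by the user; mandatory ones (hard dependencies) cannot
        if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ("true", "1") and not mandatory:
            return
        hint = f"To fix: run `pip install {requirement}`."
        require_version(requirement, hint)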
import json
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
from ..extras.misc import use_ray
@dataclass
class RayArguments:
r"""
Arguments pertaining to the Ray training.
"""
ray_run_name: Optional[str] = field(
default=None,
metadata={"help": "The training results will be saved at `saves/ray_run_name`."},
)
ray_num_workers: int = field(
default=1,
metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
)
resources_per_worker: Union[dict, str] = field(
default_factory=lambda: {"GPU": 1},
metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
)
placement_strategy: Literal["SPREAD", "PACK", "STRICT_SPREAD", "STRICT_PACK"] = field(
default="PACK",
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
)
def __post_init__(self):
self.use_ray = use_ray()
if isinstance(self.resources_per_worker, str) and self.resources_per_worker.startswith("{"):
self.resources_per_worker = _convert_str_dict(json.loads(self.resources_per_worker))
@dataclass
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r"""
Arguments pertaining to the trainer.
"""
def __post_init__(self):
Seq2SeqTrainingArguments.__post_init__(self)
RayArguments.__post_init__(self)
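The new `RayArguments` dataclass accepts `resources_per_worker` either as a dict or as a JSON-encoded string; a sketch of the normalization performed in `__post_init__` (values are illustrative):

    from llamafactory.hparams import RayArguments

    ray_args = RayArguments(
        ray_run_name="sft_llama3",                    # hypothetical run name
        ray_num_workers=2,
        resources_per_worker='{"GPU": 1, "CPU": 4}',  # string form, e.g. coming from a shell
    )
    # __post_init__ sees the leading "{", parses it with json.loads + _convert_str_dict,
    # so downstream code always receives a plain dict: {"GPU": 1, "CPU": 4}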
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
import torch
......@@ -52,7 +53,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
skip_check_imports()
model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
return {
"trust_remote_code": True,
"trust_remote_code": model_args.trust_remote_code,
"cache_dir": model_args.cache_dir,
"revision": model_args.model_revision,
"token": model_args.hf_hub_token,
......@@ -85,6 +86,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
except Exception as e:
raise OSError("Failed to load tokenizer.") from e
if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
tokenizer.model_max_length = model_args.model_max_length
if model_args.new_special_tokens is not None:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=model_args.new_special_tokens),
......@@ -155,7 +159,7 @@ def load_model(
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config, trust_remote_code=True)
model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
else:
model = load_class.from_pretrained(**init_kwargs)
......@@ -202,12 +206,8 @@ def load_model(
logger.info_rank0(param_stats)
if model_args.print_param_status:
if model_args.print_param_status and int(os.getenv("LOCAL_RANK", "0")) == 0:
for name, param in model.named_parameters():
print(
"name: {}, dtype: {}, device: {}, trainable: {}".format(
name, param.dtype, param.device, param.requires_grad
)
)
print(f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}")
return model
......@@ -15,9 +15,9 @@
from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.misc import check_version
if TYPE_CHECKING:
......@@ -35,8 +35,8 @@ def configure_attn_implementation(
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
check_version("transformers>=4.42.4")
check_version("flash_attn>=2.6.3")
if model_args.flash_attn != "fa2":
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
......
......@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
......@@ -156,7 +156,9 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
_gradient_checkpointing_enable, use_unsloth_gc=model_args.use_unsloth_gc
)
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": model_args.use_reentrant_gc}
)
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info_rank0("Gradient checkpointing enabled.")
......
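`use_reentrant_gc` is simply forwarded as the `use_reentrant` option of Hugging Face gradient checkpointing; a minimal standalone sketch of the equivalent call (model name illustrative):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    model.config.use_cache = False  # mirrored by prepare_model_for_training above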
......@@ -23,21 +23,18 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
import transformers
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils.versions import require_version
from transformers.models.llama.modeling_llama import Cache, apply_rotary_pos_emb, repeat_kv
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than
if not is_transformers_version_greater_than("4.48.0"):
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaFlashAttention2, LlamaSdpaAttention
if TYPE_CHECKING:
from transformers import PretrainedConfig
......@@ -353,7 +350,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
check_version("transformers>=4.41.2,<=4.46.1")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
......
......@@ -15,6 +15,7 @@
from typing import TYPE_CHECKING, List
from ...extras import logging
from .visual import COMPOSITE_MODELS
if TYPE_CHECKING:
......@@ -26,7 +27,7 @@ logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
r"""
Finds all available modules to apply lora or galore.
Finds all available modules to apply LoRA, GaLore or APOLLO.
"""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}
......@@ -34,18 +35,12 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output_layer")
elif model_type == "internlm2":
forbidden_modules.add("output")
elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
forbidden_modules.add("multi_modal_projector")
elif model_type == "qwen2_vl":
forbidden_modules.add("merger")
if freeze_vision_tower:
if model_type == "mllama":
forbidden_modules.add("vision_model")
elif model_type == "qwen2_vl":
forbidden_modules.add("visual")
else:
forbidden_modules.add("vision_tower")
if model_type in COMPOSITE_MODELS:
forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
if freeze_vision_tower and model_type in COMPOSITE_MODELS:
forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)
module_names = set()
for name, module in model.named_modules():
......
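The removed per-model branches are replaced by a `COMPOSITE_MODELS` registry defined in `.visual` (not shown in this commit). A hypothetical sketch of what an entry could look like, reusing the module names from the deleted branches; the class name and exact structure are assumptions:

    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass
    class CompositeModel:  # hypothetical container exposing only the two attributes used above
        projector_key: str
        vision_model_keys: List[str]

    COMPOSITE_MODELS: Dict[str, CompositeModel] = {
        "llava": CompositeModel(projector_key="multi_modal_projector", vision_model_keys=["vision_tower"]),
        "qwen2_vl": CompositeModel(projector_key="merger", vision_model_keys=["visual"]),
    }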
......@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Sequence
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from ...extras.misc import check_version
if TYPE_CHECKING:
......@@ -26,7 +27,7 @@ if TYPE_CHECKING:
def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
check_version("deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore
set_z3_leaf_modules(model, leaf_modules)
......
......@@ -39,7 +39,7 @@ def _get_unsloth_kwargs(
"device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
"trust_remote_code": model_args.trust_remote_code,
"use_gradient_checkpointing": "unsloth",
}
......
(Three additional file diffs in this commit are collapsed and not shown.)