Unverified Commit f6261d7d authored by Younes Belkada, committed by GitHub

FEAT / Optim: Add GaLore optimizer (#29588)



* add galore v1

* add import

* add tests and doc

* fix doctest

* forward contrib credits from discussions

* forward contrib credits from discussions

* Apply suggestions from code review
Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix failing tests

* switch to `optim_target_modules` and clarify docs

* more clarification

* enhance lookup logic

* update a test to add peak memory

* add regex, all-linear and single string support

* add layer-wise optimization through DummyOptimizers and LRSchedulers

* forward contrib credits from discussions and original idea

* add a section about DDP not supported in layerwise

* Update src/transformers/trainer.py
Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* fix self

* check only if layer_wise

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* oops

* make use of intervals

* clarify comment

* add matching tests

* GaLoRe -> GaLore

* move to `get_scheduler`

* add note on docs

* add a warning

* adapt a bit the docs

* update docstring

* support original API

* Update docs/source/en/trainer.md

* slightly refactor

* Update docs/source/en/trainer.md
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix args parsing and add tests

* remove warning for regex

* fix type hint

* add note about extra args

* make `is_regex` return optional

---------

Co-authored-by: Maxime <maximegmd @users.noreply.github.com>
Co-authored-by: Wing Lian <winglian @users.noreply.github.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
Co-authored-by: hiyouga <hiyouga@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
parent 484e10f7
...@@ -252,6 +252,136 @@ trainer = Trainer(..., args=training_args)
NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.
## GaLore
Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning while using less memory than common low-rank adaptation methods such as LoRA.
First make sure to install the official GaLore package:
```bash
pip install galore-torch
```
Then simply pass one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` as `optim`, together with `optim_target_modules`, which can be a list of strings, a regex, or a full path corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"]
)
model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
```
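`optim_target_modules` is flexible: besides a list of substrings as above, this PR also accepts a single regex string or the shorthand `"all-linear"` that targets every linear layer. Below is a minimal sketch of those two alternative forms (the regex pattern itself is only illustrative):

```python
from transformers import TrainingArguments

# Single regex string (illustrative pattern), matched against full module names.
args_regex = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules=r".*(attn|mlp).*",
)

# Shorthand that targets every `nn.Linear` module in the model.
args_all_linear = TrainingArguments(
    output_dir="./test-galore",
    max_steps=100,
    per_device_train_batch_size=2,
    optim="galore_adamw",
    optim_target_modules="all-linear",
)
```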
To pass extra arguments supported by GaLore, pass them through `optim_args`, for example:
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"],
optim_args="rank=64, update_proj_gap=100, scale=0.10",
)
model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
```
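If `optim_args` is omitted, the GaLore hyperparameters fall back to the defaults taken from the official repository (as wired into the `Trainer` by this PR). For reference, a sketch of those defaults and the equivalent explicit `optim_args` string:

```python
# Defaults used when `optim_args` is omitted (taken from the official GaLore repository):
galore_defaults = {
    "rank": 128,             # rank of the low-rank gradient projection
    "update_proj_gap": 200,  # number of steps between projection updates
    "scale": 0.25,           # scaling factor applied to the projected update
    "proj_type": "std",      # projection type
}

# Passing the same values explicitly is equivalent:
optim_args = "rank=128, update_proj_gap=200, scale=0.25, proj_type=std"
```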
You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).
Currently only linear layers are considered GaLore layers and are trained with low-rank decomposition, while the remaining layers are optimized in the conventional manner.
Note that it will take a bit of time before training starts (~3 minutes for a 2B model on an NVIDIA A100), but training should go smoothly afterwards.
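Since only `nn.Linear` modules are matched, it can be useful to preview which layers will actually be trained with GaLore. Here is a minimal sketch using the `check_target_module_exists` helper exposed in `transformers.trainer_utils` by this PR (the snippet itself is only illustrative):

```python
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.trainer_utils import check_target_module_exists

config = AutoConfig.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_config(config)

optim_target_modules = ["attn", "mlp"]
# Collect the linear modules whose names match `optim_target_modules`,
# mirroring the matching done in `Trainer.get_optimizer_cls_and_kwargs`.
matched = [
    name
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear) and check_target_module_exists(optim_target_modules, name)
]
print(f"{len(matched)} linear layers will be trained with GaLore, e.g. {matched[:3]}")
```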
You can also perform layer-wise optimization by appending `layerwise` to the optimizer name, as shown below:
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw_layerwise",
optim_target_modules=["attn", "mlp"]
)
model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
```
Note that layer-wise optimization is somewhat experimental and does not support DDP (Distributed Data Parallel), so you can run the training script only on a single GPU. Please see [this section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc. might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such an issue.
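For reference, layer-wise GaLore does not step a single global optimizer: the `Trainer` creates one optimizer per parameter and steps it from a post-accumulate-gradient hook, which is also why gradient accumulation is not supported. A condensed, illustrative sketch of that mechanism (the helper name below is made up; the `Trainer` wires this up for you):

```python
import torch
from galore_torch import GaLoreAdamW

def attach_layerwise_galore(model, galore_params, galore_kwargs, lr=1e-4):
    """Illustrative sketch: one optimizer per parameter, stepped from a gradient hook."""
    galore_ids = {id(p) for p in galore_params}
    optimizer_dict = {}
    for param in model.parameters():
        if not param.requires_grad:
            continue
        group = {"params": [param]}
        if id(param) in galore_ids:
            # e.g. {"rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"}
            group.update(galore_kwargs)
        optimizer_dict[param] = GaLoreAdamW([group], lr=lr)

    def optimizer_hook(param):
        # Step and reset this parameter's optimizer as soon as its gradient has been accumulated.
        if param.grad is not None:
            optimizer_dict[param].step()
            optimizer_dict[param].zero_grad()

    for param in optimizer_dict:
        param.register_post_accumulate_grad_hook(optimizer_hook)

    return optimizer_dict  # the Trainer wraps this dict in a no-op `LayerWiseDummyOptimizer`
```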
## Accelerate and Trainer
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
......
...@@ -24,6 +24,7 @@ from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
from .trainer_utils import SchedulerType
from .utils import logging
from .utils.versions import require_version
...@@ -362,6 +363,32 @@ def get_scheduler(
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
# If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
# recursively call `get_scheduler` to get the proper schedulers on each parameter
if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict = {}
for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
name,
optimizer=optimizer_dict[param],
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
)
def scheduler_hook(param):
# Since the optimizer hook has been already attached we only need to
# attach the scheduler hook
if param.grad is not None:
scheduler_dict[param].step()
for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)
return LayerWiseDummyScheduler()
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
......
...@@ -70,6 +70,7 @@ from .utils import (
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
...@@ -325,6 +326,14 @@ def require_bs4(test_case):
return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)
def require_galore_torch(test_case):
"""
Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
https://github.com/jiaweizzhao/GaLore
"""
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.
......
...@@ -83,6 +83,7 @@ from .trainer_pt_utils import (
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LayerWiseDummyOptimizer,
LengthGroupedSampler,
SequentialDistributedSampler,
distributed_broadcast_scalars,
...@@ -111,6 +112,7 @@ from .trainer_utils import (
RemoveColumnsCollator,
TrainerMemoryTracker,
TrainOutput,
check_target_module_exists,
default_compute_objective,
denumpify_detensorize,
enable_full_determinism,
...@@ -141,6 +143,7 @@ from .utils import (
is_apex_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_peft_available,
...@@ -1010,7 +1013,17 @@ class Trainer:
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
...@@ -1033,7 +1046,9 @@ class Trainer:
return self.optimizer
@staticmethod
def get_optimizer_cls_and_kwargs(
args: TrainingArguments, model: Optional[PreTrainedModel] = None
) -> Tuple[Any, Any]:
"""
Returns the optimizer class and optimizer parameters based on the training arguments.
...@@ -1171,6 +1186,132 @@ class Trainer:
optimizer_cls = torch.optim.Adagrad
elif args.optim == OptimizerNames.RMSPROP:
optimizer_cls = torch.optim.RMSprop
elif args.optim in [
OptimizerNames.GALORE_ADAMW,
OptimizerNames.GALORE_ADAMW_8BIT,
OptimizerNames.GALORE_ADAFACTOR,
OptimizerNames.GALORE_ADAMW_LAYERWISE,
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
]:
if not is_galore_torch_available():
raise ImportError(
"You need to install `galore_torch` in order to use GaLore optimizers"
" install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
)
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
is_layerwise = args.optim.lower().endswith("layerwise")
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
optimizer_mapping = {
OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
}
optimizer_cls = optimizer_mapping[args.optim]
if args.optim_target_modules is None:
raise ValueError(
"You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
)
if not isinstance(args.optim_target_modules, (list, str)):
raise ValueError(
f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
)
if model is None:
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
logger.warning(
"Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
)
all_linear = (
isinstance(args.optim_target_modules, str)
and args.optim_target_modules.replace("_", "-") == "all-linear"
)
galore_params = []
galore_params_names = []
for module_name, module in model.named_modules():
target_module_exists, is_regex = check_target_module_exists(
args.optim_target_modules, module_name, return_is_regex=True
)
if not isinstance(module, nn.Linear):
# Warn in case we match but it's not a linear layer
if target_module_exists and not is_regex:
logger.warning(
f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
)
continue
if not target_module_exists and not all_linear:
continue
galore_params.append(module.weight)
galore_params_names.append(module_name + ".weight")
if len(galore_params) == 0:
raise ValueError(
f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
)
non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
galore_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
"scale": float(optim_args.pop("scale", 0.25)),
"proj_type": optim_args.pop("proj_type", "std"),
}
# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
param_groups = [
{"params": non_galore_params},
{"params": galore_params, **galore_optim_kwargs},
]
if is_layerwise:
# For layer-wise optimizers, the optimization step is done through post accumulation
# gradient hooks. The trick is to first attach these hooks to the model parameters then
# create a dummy optimizer that will perform no-ops in the Trainer.
# See the original implementation or the nice implementation from @hiyouga
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
if args.gradient_accumulation_steps != 1:
raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
optimizer_dict = {}
for param in non_galore_params:
param_groups = [{"params": [param]}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
for param in galore_params:
param_groups = [{"params": [param], **galore_optim_kwargs}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
def optimizer_hook(param):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
for param in model.parameters():
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
optimizer_kwargs.update({"params": param_groups})
if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
......
...@@ -34,6 +34,7 @@ import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler
...@@ -1226,3 +1227,47 @@ class AcceleratorConfig:
def to_dict(self):
return copy.deepcopy(self.__dict__)
class LayerWiseDummyOptimizer(torch.optim.Optimizer):
"""
For layer-wise optimizers such as the GaLore optimizer, the optimization
step is already done through the post-gradient hooks. Therefore
the trick is to create a dummy optimizer that can take arbitrary
args and kwargs and performs a no-op during training.
Initial idea from @hiyouga in LLaMA-Factory:
https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
"""
def __init__(self, optimizer_dict=None, *args, **kwargs):
dummy_tensor = torch.randn(1, 1)
self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": 1e-03})
def zero_grad(self, set_to_none: bool = True) -> None:
pass
def step(self, closure=None) -> Optional[float]:
pass
class LayerWiseDummyScheduler(LRScheduler):
"""
For layer-wise optimizers such as the GaLore optimizer, the optimization and scheduling steps
are already done through the post-gradient hooks. Therefore
the trick is to create a dummy scheduler that can take arbitrary
args and kwargs and performs a no-op during training.
"""
def __init__(self, *args, **kwargs):
optimizer = LayerWiseDummyOptimizer()
last_epoch = -1
verbose = False
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
return [group["lr"] for group in self.optimizer.param_groups]
def _get_closed_form_lr(self):
return self.base_lrs
...@@ -785,3 +785,42 @@ class RemoveColumnsCollator:
def __call__(self, features: List[dict]):
features = [self._remove_columns(feature) for feature in features]
return self.data_collator(features)
def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
"""A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.
Args:
optim_target_modules (`Union[str, List[str]]`):
A list of strings to try to match. Can also be a single string (an exact module name or a regex).
key (`str`):
A key to search any matches in optim_target_modules
return_is_regex (`bool`):
If set to `True`, the method will return whether the passed `optim_target_modules`
is a regex or not.
Returns:
`bool`: True if `key` matches any of the target modules, False otherwise.
`bool`: Whether the match was made through a regex pattern; used to silence the warnings in Trainer
about extra modules being found. Only returned when `return_is_regex=True`.
"""
target_module_found = False
is_regex = False
if isinstance(optim_target_modules, str):
target_module_found = bool(re.fullmatch(optim_target_modules, key))
is_regex = True if not optim_target_modules == key else False
elif key in optim_target_modules: # from here, target_module_found must be a list of str
# this module is specified directly in target_modules
target_module_found = True
elif any(target_key in key for target_key in optim_target_modules):
target_module_found = True
elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules):
target_module_found = True
is_regex = True
if return_is_regex:
return target_module_found, is_regex
return target_module_found
...@@ -164,6 +164,12 @@ class OptimizerNames(ExplicitEnum):
RMSPROP_BNB = "rmsprop_bnb"
RMSPROP_8BIT = "rmsprop_bnb_8bit"
RMSPROP_32BIT = "rmsprop_bnb_32bit"
GALORE_ADAMW = "galore_adamw"
GALORE_ADAMW_8BIT = "galore_adamw_8bit"
GALORE_ADAFACTOR = "galore_adafactor"
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
...@@ -696,6 +702,12 @@ class TrainingArguments:
for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
[original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
`PeftModel` from peft.
optim_target_modules (`Union[str, List[str]]`, *optional*):
The target modules to optimize, i.e. the module names that you would like to train. Right now this is used only for the GaLore algorithm
(https://arxiv.org/abs/2403.03507); see https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaLore
optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor", and make sure that the target modules are `nn.Linear` modules
only.
""" """
framework = "pt" framework = "pt"
...@@ -1354,6 +1366,13 @@ class TrainingArguments: ...@@ -1354,6 +1366,13 @@ class TrainingArguments:
}, },
) )
optim_target_modules: Union[None, str, List[str]] = field(
default=None,
metadata={
"help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment."
},
)
def __post_init__(self):
# expand paths, if not os.makedirs("~/bar") will make directory
# in the current directory instead of the actual home
......
...@@ -125,6 +125,7 @@ from .import_utils import (
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_jieba_available,
......
...@@ -95,6 +95,7 @@ _accelerate_available, _accelerate_version = _is_package_available("accelerate",
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_galore_torch_available = _is_package_available("galore_torch")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
...@@ -309,6 +310,10 @@ def is_torchvision_available():
return _torchvision_available
def is_galore_torch_available():
return _galore_torch_available
def is_pyctcdecode_available():
return _pyctcdecode_available
......
...@@ -60,6 +60,7 @@ from transformers.testing_utils import (
require_accelerate,
require_bitsandbytes,
require_deepspeed,
require_galore_torch,
require_intel_extension_for_pytorch,
require_optuna,
require_peft,
...@@ -84,7 +85,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists
from transformers.training_args import OptimizerNames
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
...@@ -114,6 +115,8 @@ if is_torch_available():
GPT2Config,
GPT2LMHeadModel,
LineByLineTextDataset,
LlamaConfig,
LlamaForCausalLM,
PreTrainedModel,
Trainer,
TrainerState,
...@@ -146,6 +149,31 @@ class RegressionDataset:
return result
# Converting Bytes to Megabytes
def bytes2megabytes(x):
return int(x / 2**20)
# Copied from accelerate: https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py#L40C1-L73C68
class TorchTracemalloc:
def __enter__(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = torch.cuda.memory_allocated()
return self
def __exit__(self, *exc):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.end = torch.cuda.memory_allocated()
self.peak = torch.cuda.max_memory_allocated()
self.used = bytes2megabytes(self.end - self.begin)
self.peaked = bytes2megabytes(self.peak - self.begin)
@dataclasses.dataclass
class RegressionTrainingArguments(TrainingArguments):
a: float = 0.0
...@@ -1069,6 +1097,293 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
trainer.evaluate()
def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, True]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertTrue(is_regex)
exact_patterns = ["q_proj", "up_proj"]
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, True]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertFalse(is_regex)
simple_regex = r".*.attn.*"
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, False]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertTrue(is_regex)
simple_regex = "model.transformer.h.0.attn.q_proj"
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, False]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertFalse(is_regex)
target_modules = ["attn", "mlp"]
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, True]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(target_modules, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertFalse(is_regex)
@require_galore_torch
@require_torch_gpu
def test_galore(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_extra_args(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw",
optim_args="rank=64, update_proj_gap=100, scale=0.10",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_layerwise(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_layerwise",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_layerwise_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_layerwise",
lr_scheduler_type="cosine",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_adamw_8bit(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_8bit",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
upper_bound_pm = 700
lower_bound_pm = 650
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
self.assertTrue(galore_peak_memory < upper_bound_pm)
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor_attention_only(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
upper_bound_pm = 700
lower_bound_pm = 650
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
optim_target_modules=["q_proj", "k_proj", "v_proj"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
self.assertTrue(galore_peak_memory < upper_bound_pm)
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor_all_linear(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
upper_bound_pm = 700
lower_bound_pm = 650
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
optim_target_modules="all-linear",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
self.assertTrue(galore_peak_memory < upper_bound_pm)
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_torch_multi_accelerator
def test_data_is_not_parallelized_when_model_is_parallel(self):
model = RegressionModel()
......