"tests/models/patchtsmixer/__init__.py" did not exist on "dd4df80f0b77c8f8e07e502298df0121cada9ce8"
Unverified Commit f6261d7d authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

FEAT / Optim: Add GaLore optimizer (#29588)



* add galore v1

* add import

* add tests and doc

* fix doctest

* forward contrib credits from discussions

* forward contrib credits from discussions

* Apply suggestions from code review
Co-authored-by: default avatarZach Mueller <muellerzr@gmail.com>

* fix failing tests'

* switch to `optim_target_modules` and clarify docs

* more clarification

* enhance lookup logic

* update a test to add peak memory

* add regex, all-linear and single string support

* add layer-wise optimization through DummyOptimizers and LRSchedulers

* forward contrib credits from discussions and original idea

* add a section about DDP not supported in layerwise

* Update src/transformers/trainer.py
Co-authored-by: default avatarZach Mueller <muellerzr@gmail.com>

* fix self

* check only if layer_wise

* Update src/transformers/training_args.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* oops

* make use of intervals

* clarify comment

* add matching tests

* GaLoRe -> GaLore

* move to `get_scheduler`

* add note on docs

* add a warning

* adapt a bit the docs

* update docstring

* support original API

* Update docs/source/en/trainer.md

* slightly refactor

* Update docs/source/en/trainer.md
Co-authored-by: default avatarMatthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix args parsing and add tests

* remove warning for regex

* fix type hint

* add note about extra args

* make `is_regex` return optional

---------

Co-authored-by: Maxime <maximegmd @users.noreply.github.com>
Co-authored-by: Wing Lian <winglian @users.noreply.github.com>
Co-authored-by: default avatarZach Mueller <muellerzr@gmail.com>
Co-authored-by: default avatarhiyouga <hiyouga@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: default avatarMatthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
parent 484e10f7
......@@ -252,6 +252,136 @@ trainer = Trainer(..., args=training_args)
NEFTune is disabled after training to restore the original embedding layer to avoid any unexpected behavior.
## GaLore
Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning but is more memory-efficient than common low-rank adaptation methods, such as LoRA.
First make sure to install GaLore official repository:
```bash
pip install galore-torch
```
Then simply add one of `["galore_adamw", "galore_adafactor", "galore_adamw_8bit"]` in `optim` together with `optim_target_modules`, which can be a list of strings, regex or full path corresponding to the target module names you want to adapt. Below is an end-to-end example script (make sure to `pip install trl datasets`):
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"]
)
model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
```
To pass extra arguments supports by GaLore, you should pass correctly `optim_args`, for example:
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"],
optim_args="rank=64, update_proj_gap=100, scale=0.10",
)
model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
```
You can read more about the method in the [original repository](https://github.com/jiaweizzhao/GaLore) or the [paper](https://arxiv.org/abs/2403.03507).
Currently you can only train Linear layers that are considered as GaLore layers and will use low-rank decomposition to be trained while remaining layers will be optimized in the conventional manner.
Note it will take a bit of time before starting the training (~3 minutes for a 2B model on a NVIDIA A100), but training should go smoothly afterwards.
You can also perform layer-wise optimization by post-pending the optimizer name with `layerwise` like below:
```python
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw_layerwise",
optim_target_modules=["attn", "mlp"]
)
model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
```
Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.
## Accelerate and Trainer
The [`Trainer`] class is powered by [Accelerate](https://hf.co/docs/accelerate), a library for easily training PyTorch models in distributed environments with support for integrations such as [FullyShardedDataParallel (FSDP)](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) and [DeepSpeed](https://www.deepspeed.ai/).
......
......@@ -24,6 +24,7 @@ from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
from .trainer_utils import SchedulerType
from .utils import logging
from .utils.versions import require_version
......@@ -362,6 +363,32 @@ def get_scheduler(
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
# If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
# recursively call `get_scheduler` to get the proper schedulers on each parameter
if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
optimizer_dict = optimizer.optimizer_dict
scheduler_dict = {}
for param in optimizer_dict.keys():
scheduler_dict[param] = get_scheduler(
name,
optimizer=optimizer_dict[param],
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
)
def scheduler_hook(param):
# Since the optimizer hook has been already attached we only need to
# attach the scheduler hook
if param.grad is not None:
scheduler_dict[param].step()
for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook)
return LayerWiseDummyScheduler()
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
......
......@@ -70,6 +70,7 @@ from .utils import (
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
......@@ -325,6 +326,14 @@ def require_bs4(test_case):
return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)
def require_galore_torch(test_case):
"""
Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
https://github.com/jiaweizzhao/GaLore
"""
return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
def require_cv2(test_case):
"""
Decorator marking a test that requires OpenCV.
......
......@@ -83,6 +83,7 @@ from .trainer_pt_utils import (
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
LayerWiseDummyOptimizer,
LengthGroupedSampler,
SequentialDistributedSampler,
distributed_broadcast_scalars,
......@@ -111,6 +112,7 @@ from .trainer_utils import (
RemoveColumnsCollator,
TrainerMemoryTracker,
TrainOutput,
check_target_module_exists,
default_compute_objective,
denumpify_detensorize,
enable_full_determinism,
......@@ -141,6 +143,7 @@ from .utils import (
is_apex_available,
is_bitsandbytes_available,
is_datasets_available,
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_peft_available,
......@@ -1010,7 +1013,17 @@ class Trainer:
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args, opt_model)
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
......@@ -1033,7 +1046,9 @@ class Trainer:
return self.optimizer
@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
def get_optimizer_cls_and_kwargs(
args: TrainingArguments, model: Optional[PreTrainedModel] = None
) -> Tuple[Any, Any]:
"""
Returns the optimizer class and optimizer parameters based on the training arguments.
......@@ -1171,6 +1186,132 @@ class Trainer:
optimizer_cls = torch.optim.Adagrad
elif args.optim == OptimizerNames.RMSPROP:
optimizer_cls = torch.optim.RMSprop
elif args.optim in [
OptimizerNames.GALORE_ADAMW,
OptimizerNames.GALORE_ADAMW_8BIT,
OptimizerNames.GALORE_ADAFACTOR,
OptimizerNames.GALORE_ADAMW_LAYERWISE,
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
]:
if not is_galore_torch_available():
raise ImportError(
"You need to install `galore_torch` in order to use GaLore optimizers"
" install it with `pip install git+https://github.com/jiaweizzhao/GaLore`"
)
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
is_layerwise = args.optim.lower().endswith("layerwise")
if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
raise NotImplementedError("Layer-wise GaLore does not support DDP at this time")
optimizer_mapping = {
OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
}
optimizer_cls = optimizer_mapping[args.optim]
if args.optim_target_modules is None:
raise ValueError(
"You need to define a `optim_target_modules` in order to properly use GaLore optimizers"
)
if not isinstance(args.optim_target_modules, (list, str)):
raise ValueError(
f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
)
if model is None:
raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.")
logger.warning(
"Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !"
)
all_linear = (
isinstance(args.optim_target_modules, str)
and args.optim_target_modules.replace("_", "-") == "all-linear"
)
galore_params = []
galore_params_names = []
for module_name, module in model.named_modules():
target_module_exists, is_regex = check_target_module_exists(
args.optim_target_modules, module_name, return_is_regex=True
)
if not isinstance(module, nn.Linear):
# Warn in case we match but it's not a linear layer
if target_module_exists and not is_regex:
logger.warning(
f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!"
)
continue
if not target_module_exists and not all_linear:
continue
galore_params.append(module.weight)
galore_params_names.append(module_name + ".weight")
if len(galore_params) == 0:
raise ValueError(
f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`."
)
non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
galore_optim_kwargs = {
"rank": int(optim_args.pop("rank", 128)),
"update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
"scale": float(optim_args.pop("scale", 0.25)),
"proj_type": optim_args.pop("proj_type", "std"),
}
# The default args are from the official repository: https://github.com/jiaweizzhao/GaLore
param_groups = [
{"params": non_galore_params},
{"params": galore_params, **galore_optim_kwargs},
]
if is_layerwise:
# For layer-wise optimizers, the optimization step is done through post accumulation
# gradient hooks. The trick is to first attach these hooks to the model parameters then
# create a dummy optimizer that will perform no-ops in the Trainer.
# See the original implementation or the nice implementation from @hiyouga
# here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
if args.gradient_accumulation_steps != 1:
raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !")
optimizer_dict = {}
for param in non_galore_params:
param_groups = [{"params": [param]}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
for param in galore_params:
param_groups = [{"params": [param], **galore_optim_kwargs}]
optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
def optimizer_hook(param):
if param.grad is not None:
optimizer_dict[param].step()
optimizer_dict[param].zero_grad()
for param in model.parameters():
param.register_post_accumulate_grad_hook(optimizer_hook)
optimizer_cls = LayerWiseDummyOptimizer
optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
optimizer_kwargs.update({"params": param_groups})
if args.optim == OptimizerNames.GALORE_ADAFACTOR:
optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
else:
raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
return optimizer_cls, optimizer_kwargs
......
......@@ -34,6 +34,7 @@ import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler
......@@ -1226,3 +1227,47 @@ class AcceleratorConfig:
def to_dict(self):
return copy.deepcopy(self.__dict__)
class LayerWiseDummyOptimizer(torch.optim.Optimizer):
"""
For Layer-wise optimizers such as GaLoRE optimizer, the optimization
step is already done through the post gradient hooks. Therefore
the trick is to create a dummy optimizer that can take arbitrary
args and kwargs and return a no-op during training.
Initial idea from @hiyouga in LLaMA-Factory:
https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
"""
def __init__(self, optimizer_dict=None, *args, **kwargs):
dummy_tensor = torch.randn(1, 1)
self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": 1e-03})
def zero_grad(self, set_to_none: bool = True) -> None:
pass
def step(self, closure=None) -> Optional[float]:
pass
class LayerWiseDummyScheduler(LRScheduler):
"""
For Layer-wise optimizers such as GaLoRE optimizer, the optimization and scheduling step
are already done through the post gradient hooks. Therefore
the trick is to create a dummy scheduler that can take arbitrary
args and kwargs and return a no-op during training.
"""
def __init__(self, *args, **kwargs):
optimizer = LayerWiseDummyOptimizer()
last_epoch = -1
verbose = False
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
return [group["lr"] for group in self.optimizer.param_groups]
def _get_closed_form_lr(self):
return self.base_lrs
......@@ -785,3 +785,42 @@ class RemoveColumnsCollator:
def __call__(self, features: List[dict]):
features = [self._remove_columns(feature) for feature in features]
return self.data_collator(features)
def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
"""A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.
Args:
optim_target_modules (`Union[str, List[str]]`):
A list of strings to try to match. Can be also a full string.
key (`str`):
A key to search any matches in optim_target_modules
return_is_regex (`bool`):
If set to `True`, the method will return whether the passed `optim_target_modules`
is a regex or not.
Returns:
`bool` : True of match object if key matches any target modules from config, False or
None if no match found
`bool` : If the matched target module is a regex to silence out the warnings in Trainer
for extra modules being found (only if `target_module_found=True` for an array of regex).
"""
target_module_found = False
is_regex = False
if isinstance(optim_target_modules, str):
target_module_found = bool(re.fullmatch(optim_target_modules, key))
is_regex = True if not optim_target_modules == key else False
elif key in optim_target_modules: # from here, target_module_found must be a list of str
# this module is specified directly in target_modules
target_module_found = True
elif any(target_key in key for target_key in optim_target_modules):
target_module_found = True
elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules):
target_module_found = True
is_regex = True
if return_is_regex:
return target_module_found, is_regex
return target_module_found
......@@ -164,6 +164,12 @@ class OptimizerNames(ExplicitEnum):
RMSPROP_BNB = "rmsprop_bnb"
RMSPROP_8BIT = "rmsprop_bnb_8bit"
RMSPROP_32BIT = "rmsprop_bnb_32bit"
GALORE_ADAMW = "galore_adamw"
GALORE_ADAMW_8BIT = "galore_adamw_8bit"
GALORE_ADAFACTOR = "galore_adafactor"
GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
......@@ -696,6 +702,12 @@ class TrainingArguments:
for instruction fine-tuning. Check out the [original paper](https://arxiv.org/abs/2310.05914) and the
[original code](https://github.com/neelsjain/NEFTune). Support transformers `PreTrainedModel` and also
`PeftModel` from peft.
optim_target_modules (`Union[str, List[str]]`, *optional*):
The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
https://arxiv.org/abs/2403.03507
See: https://github.com/jiaweizzhao/GaLore for more details. You need to make sure to pass a valid GaloRe
optimizer, e.g. one of: "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules
only.
"""
framework = "pt"
......@@ -1354,6 +1366,13 @@ class TrainingArguments:
},
)
optim_target_modules: Union[None, str, List[str]] = field(
default=None,
metadata={
"help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment."
},
)
def __post_init__(self):
# expand paths, if not os.makedirs("~/bar") will make directory
# in the current directory instead of the actual home
......
......@@ -125,6 +125,7 @@ from .import_utils import (
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_in_notebook,
is_ipex_available,
is_jieba_available,
......
......@@ -95,6 +95,7 @@ _accelerate_available, _accelerate_version = _is_package_available("accelerate",
_apex_available = _is_package_available("apex")
_aqlm_available = _is_package_available("aqlm")
_bitsandbytes_available = _is_package_available("bitsandbytes")
_galore_torch_available = _is_package_available("galore_torch")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
_coloredlogs_available = _is_package_available("coloredlogs")
......@@ -309,6 +310,10 @@ def is_torchvision_available():
return _torchvision_available
def is_galore_torch_available():
return _galore_torch_available
def is_pyctcdecode_available():
return _pyctcdecode_available
......
......@@ -60,6 +60,7 @@ from transformers.testing_utils import (
require_accelerate,
require_bitsandbytes,
require_deepspeed,
require_galore_torch,
require_intel_extension_for_pytorch,
require_optuna,
require_peft,
......@@ -84,7 +85,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, check_target_module_exists
from transformers.training_args import OptimizerNames
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
......@@ -114,6 +115,8 @@ if is_torch_available():
GPT2Config,
GPT2LMHeadModel,
LineByLineTextDataset,
LlamaConfig,
LlamaForCausalLM,
PreTrainedModel,
Trainer,
TrainerState,
......@@ -146,6 +149,31 @@ class RegressionDataset:
return result
# Converting Bytes to Megabytes
def bytes2megabytes(x):
return int(x / 2**20)
# Copied from acclerate: https://github.com/huggingface/accelerate/blob/ee163b66fb7848892519e804688cb4ae981aacbe/src/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py#L40C1-L73C68
class TorchTracemalloc:
def __enter__(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = torch.cuda.memory_allocated()
return self
def __exit__(self, *exc):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.end = torch.cuda.memory_allocated()
self.peak = torch.cuda.max_memory_allocated()
self.used = bytes2megabytes(self.end - self.begin)
self.peaked = bytes2megabytes(self.peak - self.begin)
@dataclasses.dataclass
class RegressionTrainingArguments(TrainingArguments):
a: float = 0.0
......@@ -1069,6 +1097,293 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
trainer.evaluate()
def test_galore_matched_modules(self):
regex_patterns = [r".*.attn.*", r".*.mlp.*"]
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, True]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertTrue(is_regex)
exact_patterns = ["q_proj", "up_proj"]
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, True]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertFalse(is_regex)
simple_regex = r".*.attn.*"
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, False]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertTrue(is_regex)
simple_regex = "model.transformer.h.0.attn.q_proj"
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, False]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertFalse(is_regex)
target_modules = ["attn", "mlp"]
module_names = [
"model.transformer.h.0.ln_1",
"model.transformer.h.0.attn.q_proj",
"model.lm_head",
"model.transformer.h.0.mlp.up_proj",
]
expected_values = [False, True, False, True]
for expected_value, module_name in zip(expected_values, module_names):
is_module_matched, is_regex = check_target_module_exists(target_modules, module_name, return_is_regex=True)
self.assertTrue(is_module_matched == expected_value)
if is_module_matched:
self.assertFalse(is_regex)
@require_galore_torch
@require_torch_gpu
def test_galore(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_extra_args(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw",
optim_args="rank=64, update_proj_gap=100, scale=0.10",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_layerwise(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_layerwise",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_layerwise_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_layerwise",
lr_scheduler_type="cosine",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_adamw_8bit(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adamw_8bit",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
upper_bound_pm = 700
lower_bound_pm = 650
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
optim_target_modules=[r".*attn.*", r".*mlp.*"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
self.assertTrue(galore_peak_memory < upper_bound_pm)
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor_attention_only(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
upper_bound_pm = 700
lower_bound_pm = 650
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
optim_target_modules=["q_proj", "k_proj", "v_proj"],
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
self.assertTrue(galore_peak_memory < upper_bound_pm)
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_galore_torch
@require_torch_gpu
def test_galore_adafactor_all_linear(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
upper_bound_pm = 700
lower_bound_pm = 650
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)
with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir,
learning_rate=1e-9,
logging_steps=5,
optim="galore_adafactor",
optim_target_modules="all-linear",
)
trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
# Check this works
_ = trainer.train()
galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
self.assertTrue(galore_peak_memory < upper_bound_pm)
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_torch_multi_accelerator
def test_data_is_not_parallelized_when_model_is_parallel(self):
model = RegressionModel()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment