Unverified Commit 43ffb785 authored by Billy Cao, committed by GitHub

Add torch_empty_cache_steps to TrainingArguments (#31546)

* Add torch_empty_cache_steps to TrainingArguments

* Fix formatting

* Add torch_empty_cache_steps to docs on single gpu training

* Remove check for torch_empty_cache_steps <= max_steps

* Capitalize Tip

* Be device agnostic

* Fix linting
parent cee768d9
@@ -42,11 +42,12 @@ hyperparameter tuning, you should determine which batch size yields the best res
 The methods and tools covered in this guide can be classified based on the effect they have on the training process:
 
 | Method/tool | Improves training speed | Optimizes memory utilization |
-|:-----------------------------------------------------------|:------------------------|:-----------------------------|
+|:--------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------|:-----------------------------|
 | [Batch size choice](#batch-size-choice) | Yes | Yes |
 | [Gradient accumulation](#gradient-accumulation) | No | Yes |
 | [Gradient checkpointing](#gradient-checkpointing) | No | Yes |
-| [Mixed precision training](#mixed-precision-training) | Yes | (No) |
+| [Mixed precision training](#mixed-precision-training) | Yes | Maybe* |
+| [torch_empty_cache_steps](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments.torch_empty_cache_steps) | No | Yes |
 | [Optimizer choice](#optimizer-choice) | Yes | Yes |
 | [Data preloading](#data-preloading) | Yes | No |
 | [DeepSpeed Zero](#deepspeed-zero) | No | Yes |
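As a usage illustration of the `torch_empty_cache_steps` entry added to this table, here is a minimal sketch; the output directory, batch size, and the value of 100 steps are placeholders chosen for the example, not recommendations from this change.

```python
from transformers import TrainingArguments

# Clear the accelerator cache every 100 training steps. Leaving the argument at its
# default (None) keeps the previous behaviour: the cache is never emptied explicitly.
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    torch_empty_cache_steps=100,
)

# `args` is then passed to `Trainer(...)` as usual; the periodic cache clearing happens
# inside `Trainer.training_step` (see the trainer.py hunk below).
print(args.torch_empty_cache_steps)  # 100
```

The trade-off is the one documented in the linked issue: lower peak VRAM in exchange for roughly 10% slower training.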
@@ -55,7 +56,7 @@ The methods and tools covered in this guide can be classified based on the effec
 <Tip>
 
-Note: when using mixed precision with a small model and a large batch size, there will be some memory savings but with a
+*Note: when using mixed precision with a small model and a large batch size, there will be some memory savings but with a
 large model and a small batch size, the memory use will be larger.
 
 </Tip>
@@ -221,6 +221,11 @@ if is_accelerate_available():
         DistributedDataParallelKwargs,
         DistributedType,
         GradientAccumulationPlugin,
+        is_mlu_available,
+        is_mps_available,
+        is_npu_available,
+        is_torch_version,
+        is_xpu_available,
         load_fsdp_model,
         load_fsdp_optimizer,
         save_fsdp_model,
@@ -3307,6 +3312,20 @@ class Trainer:
             loss = self.compute_loss(model, inputs)
 
         del inputs
+        if (
+            self.args.torch_empty_cache_steps is not None
+            and self.state.global_step % self.args.torch_empty_cache_steps == 0
+        ):
+            if is_xpu_available():
+                torch.xpu.empty_cache()
+            elif is_mlu_available():
+                torch.mlu.empty_cache()
+            elif is_npu_available():
+                torch.npu.empty_cache()
+            elif is_torch_version(">=", "2.0") and is_mps_available():
+                torch.mps.empty_cache()
+            else:
+                torch.cuda.empty_cache()
 
         kwargs = {}
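For readers who want the same behaviour outside of `Trainer`, here is a hedged standalone sketch of the device-agnostic cache-clearing idea used above. It is an illustration only: the helper name `empty_device_cache` is invented for this example, and it relies on `torch`'s own availability checks rather than the `accelerate` helpers imported in this diff.

```python
import torch


def empty_device_cache() -> None:
    """Illustrative helper: release cached memory on whichever backend is active."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        # torch.mps.empty_cache() is available from PyTorch 2.0 onwards.
        torch.mps.empty_cache()
    # On a CPU-only setup there is no device cache to clear.


# Example: call the helper every N steps of a custom training loop.
N = 100
for step in range(1, 1001):
    # ... forward / backward / optimizer.step() would go here ...
    if step % N == 0:
        empty_device_cache()
```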
@@ -267,6 +267,15 @@ class TrainingArguments:
         eval_delay (`float`, *optional*):
             Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
             eval_strategy.
+        torch_empty_cache_steps (`int`, *optional*):
+            Number of steps to wait before calling `torch.<device>.empty_cache()`. If left unset or set to None, the cache will not be emptied.
+
+            <Tip>
+
+            This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372).
+
+            </Tip>
+
         learning_rate (`float`, *optional*, defaults to 5e-5):
             The initial learning rate for [`AdamW`] optimizer.
         weight_decay (`float`, *optional*, defaults to 0):
@@ -851,6 +860,15 @@ class TrainingArguments:
         },
     )
+    torch_empty_cache_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Number of steps to wait before calling `torch.<device>.empty_cache()`. "
+            "This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372). "
+            "If left unset or set to None, the cache will not be emptied."
+        },
+    )
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
     adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
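Because the `help` metadata above is what `HfArgumentParser` exposes on the command line, here is a short sketch of how the new field surfaces as a flag; the flag value of 50 and the inline `args=[...]` list are illustrative stand-ins for real CLI arguments.

```python
from transformers import HfArgumentParser, TrainingArguments

# Equivalent to running a training script with:
#   --output_dir out --torch_empty_cache_steps 50
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--torch_empty_cache_steps", "50"]
)
print(training_args.torch_empty_cache_steps)  # 50
```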
@@ -1532,6 +1550,12 @@ class TrainingArguments:
         if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO:
             self.do_eval = True
 
+        if self.torch_empty_cache_steps is not None:
+            if not (isinstance(self.torch_empty_cache_steps, int) and self.torch_empty_cache_steps > 0):
+                raise ValueError(
+                    f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}."
+                )
+
         # eval_steps has to be defined and non-zero, fallbacks to logging_steps if the latter is non-zero
         if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
             if self.logging_steps > 0:
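A quick sketch of what this `__post_init__` check is meant to accept and reject, assuming the positive-integer validation behaves as its error message states; the values are arbitrary examples.

```python
from transformers import TrainingArguments

# A positive integer passes validation and is stored unchanged.
ok = TrainingArguments(output_dir="out", torch_empty_cache_steps=50)
assert ok.torch_empty_cache_steps == 50

# Zero (or any other non-positive or non-integer value) should raise in __post_init__.
try:
    TrainingArguments(output_dir="out", torch_empty_cache_steps=0)
except ValueError as err:
    print(err)  # `torch_empty_cache_steps` must be an integer bigger than 0, got 0.
```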