Unverified Commit 738ecd17 authored by Arup De, committed by GitHub

Arde/fsdp activation checkpointing (#25771)

* add FSDP config option to enable activation-checkpointing

* update docs

* add checks and remove redundant code

* fix formatting error
parent 50573c64
...@@ -456,6 +456,10 @@ as the model saving with FSDP activated is only available with recent fixes.
  If `"True"`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass.
- `limit_all_gathers` can be specified in the config file.
  If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers.
- `activation_checkpointing` can be specified in the config file.
  If `"True"`, FSDP applies activation checkpointing, a technique to reduce memory usage by clearing the
  activations of certain layers and recomputing them during the backward pass. Effectively, this trades
  extra computation time for reduced memory usage (see the config sketch below).
**A few caveats to be aware of**
- it is incompatible with `generate`, and is therefore incompatible with `--predict_with_generate`
...
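For illustration, here is a minimal sketch of a config enabling the options documented above. The key names (`forward_prefetch`, `limit_all_gathers`, `activation_checkpointing`) come from the list in this hunk; the file name `fsdp_config.json` and the idea of passing it via `--fsdp_config` are assumptions for the example, not part of this diff.

```python
# Sketch only: write an FSDP config that enables the options documented above.
# Key names are taken from the docs in this hunk; the file name "fsdp_config.json"
# is a hypothetical example intended to be passed via `--fsdp_config`.
import json

fsdp_config = {
    "forward_prefetch": True,          # prefetch the next all-gather during the forward pass
    "limit_all_gathers": True,         # sync the CPU thread to cap in-flight all-gathers
    "activation_checkpointing": True,  # recompute activations in the backward pass to save memory
}

with open("fsdp_config.json", "w") as f:
    json.dump(fsdp_config, f, indent=2)
```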
...@@ -3896,6 +3896,15 @@ class Trainer:
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
    "limit_all_gathers", fsdp_plugin.limit_all_gathers
)
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
    "activation_checkpointing", fsdp_plugin.activation_checkpointing
)
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
    raise ValueError(
        "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
        "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
        "when using FSDP."
    )
if self.is_deepspeed_enabled:
    if getattr(self.args, "hf_deepspeed_config", None) is None:
...
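To show how the new option and the guard above interact from the user side, here is a minimal sketch, assuming `fsdp_config` accepts an in-memory dict (it can also be given as a path to a JSON config file) and with `output_dir="out"` as a placeholder.

```python
# Sketch only: enable FSDP's own activation checkpointing and leave the Trainer's
# gradient_checkpointing flag off; setting both would trigger the ValueError added above.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",                  # hypothetical output directory
    fsdp="full_shard",                 # turn on FSDP sharding
    fsdp_config={"activation_checkpointing": True},
    gradient_checkpointing=False,      # must stay False when activation_checkpointing is True
)
```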
...@@ -482,6 +482,10 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
  If `True`, activation checkpointing is applied, a technique to reduce memory usage by clearing the
  activations of certain layers and recomputing them during the backward pass. Effectively, this trades
  extra computation time for reduced memory usage.
deepspeed (`str` or `dict`, *optional*):
    Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
...
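As a companion to the docstring context above, here is a sketch of the XLA variant it mentions. The key names (`xla`, `xla_fsdp_grad_ckpt`, `fsdp_transformer_layer_cls_to_wrap`) are taken from the surrounding docstring; the layer class name is a hypothetical placeholder, and the exact value types may differ by version.

```python
# Sketch only: the XLA gradient-checkpointing path described above requires the xla
# flag plus an auto wrapping policy. "BertLayer" is a hypothetical placeholder for
# your model's transformer block class.
xla_fsdp_config = {
    "xla": True,
    "xla_fsdp_grad_ckpt": True,
    "fsdp_transformer_layer_cls_to_wrap": ["BertLayer"],
}
```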