Unverified commit ad35309a authored by Yun Dai, committed by GitHub

add warning when using gradient_checkpointing with FSDP full shard (#31578)



* add warning when using gradient_checkpointing with FSDP full shard

* fix style

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add hybrid shard warn

* fix style

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 6176d8f5
@@ -1820,7 +1820,7 @@ class TrainingArguments:
             raise ValueError("warmup_steps must be either 0 or > 1")
         if isinstance(self.fsdp, bool):
-            self.fsdp = "full_shard" if self.fsdp else ""
+            self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else ""
         if isinstance(self.fsdp, str):
             self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
         if self.fsdp == [FSDPOption.OFFLOAD]:
@@ -1831,6 +1831,15 @@
         elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
+        if self.gradient_checkpointing and (
+            FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
+        ):
+            logger.warning(
+                "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
+                " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
+                " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
+            )
         if self.fsdp_config is None:
             self.fsdp_config = {}
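For context, a minimal sketch of the setup the new warning recommends: enable checkpointing through the `activation_checkpointing` key of `fsdp_config` rather than the top-level `gradient_checkpointing` flag, so that checkpointing is applied by FSDP itself and the redundant backward-pass AllGather described in issue #30404 is avoided. The `output_dir` value is a placeholder; the `fsdp` and `fsdp_config` options shown are the documented `TrainingArguments` fields.

```python
from transformers import TrainingArguments

# Sketch only: FSDP full sharding with FSDP-managed activation checkpointing.
args = TrainingArguments(
    output_dir="out",                      # placeholder output path
    fsdp="full_shard auto_wrap",           # FSDP full shard
    fsdp_config={
        "activation_checkpointing": True,  # checkpoint activations inside FSDP
    },
    # gradient_checkpointing=True,         # avoid: would now trigger the warning above
)
```

With this configuration the condition added in the diff (`self.gradient_checkpointing` together with `FULL_SHARD` or `HYBRID_SHARD`) is never met, so no warning is emitted.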