"vscode:/vscode.git/clone" did not exist on "55bcae7f25678bd8f40cda1a076ef9773047c1fe"
Unverified Commit 571fa585 authored by raghavanone's avatar raghavanone Committed by GitHub
Browse files

Add limit_all_gathers option to fsdp_config and fix forward_prefetch bug (#21489)

* Add limit_all_gathers option to fsdp_config and fix forward_prefetch bug

* Fix black issue

* Fix ruff failure

* Incorporate PR feedbacks

* Incorporate PR feedbacks

* Incorporate PR feedbacks
parent 479322bf
...@@ -441,9 +441,13 @@ class Trainer: ...@@ -441,9 +441,13 @@ class Trainer:
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
self.forword_prefetch = False self.forword_prefetch = False
if "forword_prefetch" in self.args.fsdp_config and self.backward_prefetch: if self.args.fsdp_config.get("forword_prefect", False):
self.forword_prefetch = True self.forword_prefetch = True
self.limit_all_gathers = False
if self.args.fsdp_config.get("limit_all_gathers", False):
self.limit_all_gathers = True
# one place to sort out whether to place the model on device or not # one place to sort out whether to place the model on device or not
# postpone switching model to cuda when: # postpone switching model to cuda when:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model # 1. MP - since we are trying to fit a much bigger than 1 gpu model
...@@ -1462,6 +1466,7 @@ class Trainer: ...@@ -1462,6 +1466,7 @@ class Trainer:
device_id=self.args.device, device_id=self.args.device,
backward_prefetch=self.backward_prefetch, backward_prefetch=self.backward_prefetch,
forward_prefetch=self.forword_prefetch, forward_prefetch=self.forword_prefetch,
limit_all_gathers=self.limit_all_gathers,
) )
elif is_sagemaker_dp_enabled(): elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel( model = nn.parallel.DistributedDataParallel(
......
...@@ -432,6 +432,10 @@ class TrainingArguments: ...@@ -432,6 +432,10 @@ class TrainingArguments:
FSDP's forward prefetch mode (useful only when `fsdp` field is passed). FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
forward pass. forward pass.
- limit_all_gathers (`bool`, *optional*, defaults to `False`)
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
deepspeed (`str` or `dict`, *optional*): deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g., evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment