"docs/source/en/model_doc/perceiver.md" did not exist on "074490b2c24d4d8f11f05599016ee8d50823f06a"
Unverified commit b89fcccd authored by Sourab Mangrulkar, committed by GitHub

update FSDP save and load logic (#24249)

* update fsdp save and load logic

* fix

* see if this resolves the failing tests
parent e0603d89
@@ -216,6 +216,14 @@ if is_accelerate_available():
     from accelerate import Accelerator
     from accelerate.utils import DistributedDataParallelKwargs
+
+    if version.parse(accelerate_version) > version.parse("0.20.3"):
+        from accelerate.utils import (
+            load_fsdp_model,
+            load_fsdp_optimizer,
+            save_fsdp_model,
+            save_fsdp_optimizer,
+        )
 
 if TYPE_CHECKING:
     import optuna
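
The guard above suggests the standalone FSDP checkpoint helpers are only available in accelerate releases newer than 0.20.3. A self-contained sketch of the same gating pattern outside the Trainer, using accelerate's `__version__` directly instead of the `accelerate_version` alias the Trainer already has in scope:

```python
from packaging import version

import accelerate

# Only import the standalone FSDP checkpoint helpers when the installed
# accelerate release is newer than 0.20.3, mirroring the guard in the diff.
if version.parse(accelerate.__version__) > version.parse("0.20.3"):
    from accelerate.utils import (
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
        save_fsdp_optimizer,
    )
```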
@@ -2035,7 +2043,7 @@ class Trainer:
             # release memory
             del state_dict
         elif self.is_fsdp_enabled:
-            self.accelerator.state.fsdp_plugin.load_model(self.accelerator, model, resume_from_checkpoint)
+            load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)
         else:
             # We load the model state dict on the CPU to avoid an OOM error.
             if self.args.save_safetensors and os.path.isfile(safe_weights_file):
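
When resuming from a checkpoint, the sharded model weights are now restored with `load_fsdp_model`, which takes the FSDP plugin as an explicit first argument instead of being a method on it. A minimal sketch of that call outside the Trainer, assuming the process was launched with an accelerate FSDP config so `accelerator.state.fsdp_plugin` is populated (the toy model and checkpoint path are illustrative, not from the diff):

```python
import torch
from accelerate import Accelerator
from accelerate.utils import load_fsdp_model

accelerator = Accelerator()  # picks up the FSDP plugin from the accelerate config
model = accelerator.prepare(torch.nn.Linear(16, 16))  # prepare() wraps the model in FSDP

# Argument order as used in the diff: plugin, accelerator, model, checkpoint directory
# (a directory written by an earlier run).
load_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, "checkpoint-500")
```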
@@ -2096,8 +2104,8 @@ class Trainer:
                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
             elif self.is_fsdp_enabled:
-                self.accelerator.state.fsdp_plugin.load_model(
-                    self.accelerator, model, self.state.best_model_checkpoint
+                load_fsdp_model(
+                    self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
                 )
             else:
                 if is_peft_available() and isinstance(model, PeftModel):
@@ -2269,7 +2277,12 @@ class Trainer:
             if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                 self.optimizer.consolidate_state_dict()
 
-        if self.fsdp:
+        if self.fsdp or self.is_fsdp_enabled:
+            if self.is_fsdp_enabled:
+                save_fsdp_optimizer(
+                    self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
+                )
+            else:
             # FSDP has a different interface for saving optimizer states.
             # Needs to be called on all ranks to gather all states.
             # full_optim_state_dict will be deprecated after Pytorch 2.2!
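
On the accelerate-backed path (`is_fsdp_enabled`), optimizer state is now written by `save_fsdp_optimizer`, while the older `self.fsdp` path keeps the `full_optim_state_dict` gather described by the comments above. A short sketch of the new call, reusing the `accelerator` and FSDP-wrapped `model` from the previous sketch ("checkpoint-500" remains an illustrative output directory):

```python
import torch
from accelerate.utils import save_fsdp_optimizer

# An optimizer prepared alongside the FSDP-wrapped model from the earlier sketch.
optimizer = accelerator.prepare(torch.optim.AdamW(model.parameters(), lr=1e-3))

# Argument order as used in the diff: plugin, accelerator, optimizer, model, output directory.
save_fsdp_optimizer(accelerator.state.fsdp_plugin, accelerator, optimizer, model, "checkpoint-500")
```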
@@ -2413,7 +2426,16 @@ class Trainer:
             # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
             # likely to get OOM on CPU (since we load num_gpu times the optimizer state
             map_location = self.args.device if self.args.world_size > 1 else "cpu"
-            if self.fsdp:
+            if self.fsdp or self.is_fsdp_enabled:
+                if self.is_fsdp_enabled:
+                    load_fsdp_optimizer(
+                        self.accelerator.state.fsdp_plugin,
+                        self.accelerator,
+                        self.optimizer,
+                        self.model,
+                        checkpoint,
+                    )
+                else:
                 full_osd = None
                 # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
                 if self.args.process_index == 0:
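
Restoring optimizer state mirrors the save: on the accelerate-backed path a single `load_fsdp_optimizer` call replaces the rank-0 load-then-shard sequence that the legacy branch performs. A sketch using the `accelerator`, `model`, and `optimizer` from the sketches above, pointing at the same illustrative checkpoint directory:

```python
from accelerate.utils import load_fsdp_optimizer

# Same positional layout as save_fsdp_optimizer, but reading the checkpoint back in.
load_fsdp_optimizer(accelerator.state.fsdp_plugin, accelerator, optimizer, model, "checkpoint-500")
```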
@@ -2724,7 +2746,7 @@ class Trainer:
         ):
             if self.is_fsdp_enabled:
                 os.makedirs(output_dir, exist_ok=True)
-                self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
+                save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
             else:
                 state_dict = self.model.state_dict()
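
Finally, when `save_model` runs under accelerate-backed FSDP, the weights go through `save_fsdp_model` rather than the old plugin method. A sketch with the same assumed `accelerator` and FSDP-wrapped `model` as above, writing to an illustrative directory:

```python
import os

from accelerate.utils import save_fsdp_model

output_dir = "checkpoint-500"  # illustrative
os.makedirs(output_dir, exist_ok=True)

# Argument order as used in the diff: plugin, accelerator, model, output directory.
save_fsdp_model(accelerator.state.fsdp_plugin, accelerator, model, output_dir)
```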