Unverified Commit b89fcccd authored by Sourab Mangrulkar, committed by GitHub

update FSDP save and load logic (#24249)

* update fsdp save and load logic

* fix

* see if this resolves the failing tests
parent e0603d89
@@ -216,6 +216,14 @@ if is_accelerate_available():
     from accelerate import Accelerator
     from accelerate.utils import DistributedDataParallelKwargs
 
+    if version.parse(accelerate_version) > version.parse("0.20.3"):
+        from accelerate.utils import (
+            load_fsdp_model,
+            load_fsdp_optimizer,
+            save_fsdp_model,
+            save_fsdp_optimizer,
+        )
+
 if TYPE_CHECKING:
     import optuna
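Note that the new helpers only exist in accelerate releases newer than 0.20.3, which is why the import above is version-gated. As a minimal standalone sketch of the same guard (assuming accelerate exposes __version__; transformers resolves accelerate_version elsewhere in the file):

from packaging import version

import accelerate

# The FSDP checkpoint helpers landed in accelerate after 0.20.3; older releases
# fall back to the fsdp_plugin.save_model / load_model methods removed below.
if version.parse(accelerate.__version__) > version.parse("0.20.3"):
    from accelerate.utils import (
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
        save_fsdp_optimizer,
    )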
@@ -2035,7 +2043,7 @@ class Trainer:
             # release memory
             del state_dict
         elif self.is_fsdp_enabled:
-            self.accelerator.state.fsdp_plugin.load_model(self.accelerator, model, resume_from_checkpoint)
+            load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)
         else:
             # We load the model state dict on the CPU to avoid an OOM error.
             if self.args.save_safetensors and os.path.isfile(safe_weights_file):
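The same load_fsdp_model call shows up again in the next hunk when the best model is reloaded at the end of training. A small wrapper sketches the pattern in isolation; load_fsdp_weights, use_new_api, and checkpoint_dir are illustrative names, load_fsdp_model is assumed to come from the version-gated import above, and the fallback branch is the pre-0.20.3 plugin method on the removed line.

def load_fsdp_weights(fsdp_plugin, accelerator, model, checkpoint_dir, use_new_api=True):
    # Restore FSDP-sharded model weights from a checkpoint directory.
    if use_new_api:
        # accelerate > 0.20.3: functional helper from accelerate.utils
        load_fsdp_model(fsdp_plugin, accelerator, model, checkpoint_dir)
    else:
        # Older accelerate: method on the FSDP plugin object itself
        fsdp_plugin.load_model(accelerator, model, checkpoint_dir)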
@@ -2096,8 +2104,8 @@ class Trainer:
                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
             elif self.is_fsdp_enabled:
-                self.accelerator.state.fsdp_plugin.load_model(
-                    self.accelerator, model, self.state.best_model_checkpoint
+                load_fsdp_model(
+                    self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
                 )
             else:
                 if is_peft_available() and isinstance(model, PeftModel):
@@ -2269,11 +2277,16 @@ class Trainer:
         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
             self.optimizer.consolidate_state_dict()
 
-        if self.fsdp:
-            # FSDP has a different interface for saving optimizer states.
-            # Needs to be called on all ranks to gather all states.
-            # full_optim_state_dict will be deprecated after Pytorch 2.2!
-            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
+        if self.fsdp or self.is_fsdp_enabled:
+            if self.is_fsdp_enabled:
+                save_fsdp_optimizer(
+                    self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
+                )
+            else:
+                # FSDP has a different interface for saving optimizer states.
+                # Needs to be called on all ranks to gather all states.
+                # full_optim_state_dict will be deprecated after Pytorch 2.2!
+                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
 
         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
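For the optimizer state, the hunk keeps both paths: the accelerate helper when the Trainer was prepared through accelerate (is_fsdp_enabled), and the plain PyTorch gather for the older self.fsdp integration. A sketch of the two branches outside the Trainer; the function and parameter names and the "optimizer.pt" filename are illustrative, save_fsdp_optimizer is assumed imported as above, and in the Trainer the rank-0 write of full_osd happens further down in the same method.

import os

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_fsdp_optimizer_state(fsdp_plugin, accelerator, optimizer, model, output_dir, use_new_api=True):
    if use_new_api:
        # accelerate > 0.20.3: the helper gathers and writes the optimizer state itself.
        save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir)
    else:
        # Legacy path: full_optim_state_dict must run on every rank to gather the
        # sharded optimizer state; only the main process then writes it to disk.
        full_osd = FSDP.full_optim_state_dict(model, optimizer)
        if accelerator.is_main_process:
            torch.save(full_osd, os.path.join(output_dir, "optimizer.pt"))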
@@ -2413,14 +2426,23 @@ class Trainer:
             # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
             # likely to get OOM on CPU (since we load num_gpu times the optimizer state
             map_location = self.args.device if self.args.world_size > 1 else "cpu"
-            if self.fsdp:
-                full_osd = None
-                # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
-                if self.args.process_index == 0:
-                    full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
-                # call scatter_full_optim_state_dict on all ranks
-                sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
-                self.optimizer.load_state_dict(sharded_osd)
+            if self.fsdp or self.is_fsdp_enabled:
+                if self.is_fsdp_enabled:
+                    load_fsdp_optimizer(
+                        self.accelerator.state.fsdp_plugin,
+                        self.accelerator,
+                        self.optimizer,
+                        self.model,
+                        checkpoint,
+                    )
+                else:
+                    full_osd = None
+                    # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
+                    if self.args.process_index == 0:
+                        full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
+                    # call scatter_full_optim_state_dict on all ranks
+                    sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
+                    self.optimizer.load_state_dict(sharded_osd)
             else:
                 self.optimizer.load_state_dict(
                     torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
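The load side of the legacy branch mirrors the save: rank 0 reads the full optimizer state dict from disk, every rank calls scatter_full_optim_state_dict to re-shard it, and the sharded result feeds optimizer.load_state_dict. A sketch of just that branch, with illustrative names and the same "optimizer.pt" filename as above:

import os

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def load_legacy_fsdp_optimizer_state(optimizer, model, checkpoint_dir, process_index):
    full_osd = None
    # Only rank 0 materializes the full (unsharded) optimizer state dict.
    if process_index == 0:
        full_osd = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
    # Must be called on all ranks: re-shards the full state to match each rank's FSDP shards.
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
    optimizer.load_state_dict(sharded_osd)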
@@ -2724,7 +2746,7 @@ class Trainer:
         ):
             if self.is_fsdp_enabled:
                 os.makedirs(output_dir, exist_ok=True)
-                self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
+                save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
             else:
                 state_dict = self.model.state_dict()
......
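Taken together, the four helpers map onto a plain accelerate training script as well. A minimal round-trip sketch under that assumption (model and optimizer already passed through accelerator.prepare with an FSDP plugin configured, the helpers imported as in the first hunk, directory names illustrative):

import os


def checkpoint_fsdp(accelerator, model, optimizer, output_dir):
    # Writes model weights and optimizer state in whatever format the plugin's
    # state_dict_type selects (sharded or full).
    fsdp_plugin = accelerator.state.fsdp_plugin
    os.makedirs(output_dir, exist_ok=True)
    save_fsdp_model(fsdp_plugin, accelerator, model, output_dir)
    save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir)


def restore_fsdp(accelerator, model, optimizer, checkpoint_dir):
    # Reloads both in the same order; must be called on every process.
    fsdp_plugin = accelerator.state.fsdp_plugin
    load_fsdp_model(fsdp_plugin, accelerator, model, checkpoint_dir)
    load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, checkpoint_dir)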