Unverified Commit 2200bf7a authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`Trainer`] Correct behavior of `_load_best_model` for PEFT models (#24103)

* v1

* some refactor

- add ST format as well

* fix

* add `ADAPTER_WEIGHTS_NAME` & `ADAPTER_SAFE_WEIGHTS_NAME`
parent 0f236050
...@@ -134,6 +134,8 @@ from .trainer_utils import ( ...@@ -134,6 +134,8 @@ from .trainer_utils import (
) )
from .training_args import OptimizerNames, ParallelMode, TrainingArguments from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import ( from .utils import (
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME, CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_NAME,
...@@ -2177,11 +2179,20 @@ class Trainer: ...@@ -2177,11 +2179,20 @@ class Trainer:
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if os.path.exists(best_model_path) or os.path.exists(best_safe_model_path): if (
os.path.exists(best_model_path)
or os.path.exists(best_safe_model_path)
or os.path.exists(best_adapter_model_path)
or os.path.exists(best_safe_adapter_model_path)
):
if self.is_deepspeed_enabled: if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint) deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
else: else:
has_been_loaded = True
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
# If the 'user_content.pt' file exists, load with the new smp api. # If the 'user_content.pt' file exists, load with the new smp api.
...@@ -2207,10 +2218,10 @@ class Trainer: ...@@ -2207,10 +2218,10 @@ class Trainer:
self.accelerator, model, self.state.best_model_checkpoint self.accelerator, model, self.state.best_model_checkpoint
) )
else: else:
if hasattr(model, "base_model") and getattr(model.base_model, "is_8bit_serializable", False): if is_peft_available() and isinstance(model, PeftModel):
# If train base_8_bit_models using PEFT & LoRA, assume that adapter have been saved properly. # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
if os.path.exists(os.path.join(self.state.best_model_checkpoint, "adapter_model.bin")): if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
# Load_adapter has no return value present, modify it when appropriate. # Load_adapter has no return value present, modify it when appropriate.
from torch.nn.modules.module import _IncompatibleKeys from torch.nn.modules.module import _IncompatibleKeys
...@@ -2219,12 +2230,13 @@ class Trainer: ...@@ -2219,12 +2230,13 @@ class Trainer:
else: else:
logger.warning( logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, " "The intermediate checkpoints of PEFT may not be saved correctly, "
"using `TrainerCallback` to save adapter_model.bin in corresponding folders, " f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
"here are some examples https://github.com/huggingface/peft/issues/96" "here are some examples https://github.com/huggingface/peft/issues/96"
) )
has_been_loaded = False
else: else:
# We can't do pure 8bit training using transformers. logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
logger.warning("Could not loading a quantized checkpoint.") has_been_loaded = False
else: else:
# We load the model state dict on the CPU to avoid an OOM error. # We load the model state dict on the CPU to avoid an OOM error.
if self.args.save_safetensors and os.path.isfile(best_safe_model_path): if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
...@@ -2236,7 +2248,7 @@ class Trainer: ...@@ -2236,7 +2248,7 @@ class Trainer:
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs # which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False) load_result = model.load_state_dict(state_dict, False)
if not is_sagemaker_mp_enabled(): if not is_sagemaker_mp_enabled() and has_been_loaded:
self._issue_warnings_after_load(load_result) self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
load_result = load_sharded_checkpoint( load_result = load_sharded_checkpoint(
......
...@@ -177,6 +177,8 @@ from .import_utils import ( ...@@ -177,6 +177,8 @@ from .import_utils import (
WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
TF2_WEIGHTS_NAME = "tf_model.h5" TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json" TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt" TF_WEIGHTS_NAME = "model.ckpt"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment