Unverified commit 183f442b authored by Llohann Dallagnol Speranca, committed by GitHub

Fix resuming PeftModel checkpoints in Trainer (#24274)



* Fix resuming checkpoints for PeftModels

Fix an error that occurred when resuming a PeftModel from a training checkpoint. The failure happened because `PeftModel.save_pretrained` saves only adapter-related data, while `_load_from_checkpoint` expected a torch-saved model. This PR fixes the issue and allows the adapter checkpoint to be loaded.

Resolves: #24252
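
For context, `PeftModel.save_pretrained` writes only the adapter files rather than a full model state dict, so the old existence check in `_load_from_checkpoint` never matched. A minimal sketch of the distinction (the directory path is hypothetical; filenames assume the usual transformers/PEFT conventions):

```python
import os

# Hypothetical Trainer checkpoint directory.
checkpoint_dir = "output/checkpoint-500"

full_model_files = ["pytorch_model.bin", "model.safetensors"]       # full state dicts
adapter_files = ["adapter_model.bin", "adapter_model.safetensors"]  # what PeftModel saves

has_full_model = any(os.path.isfile(os.path.join(checkpoint_dir, f)) for f in full_model_files)
has_adapter = any(os.path.isfile(os.path.join(checkpoint_dir, f)) for f in adapter_files)

# Before this fix, only has_full_model was accepted, so a PEFT checkpoint
# (adapter files only) raised "Can't find a valid checkpoint at ...".
print(has_full_model, has_adapter)
```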

* fix last comment

* fix nits

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent 0875b250
@@ -1979,14 +1979,23 @@ class Trainer:
         model = self.model
         config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
+        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
+        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
         weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
         weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
         safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
         safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
         if not any(
-            os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]
+            os.path.isfile(f)
+            for f in [
+                weights_file,
+                safe_weights_file,
+                weights_index_file,
+                safe_weights_index_file,
+                adapter_weights_file,
+                adapter_safe_weights_file,
+            ]
         ):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
@@ -2039,6 +2048,21 @@ class Trainer:
                 # release memory
                 del state_dict
                 self._issue_warnings_after_load(load_result)
+        # Load adapters following PR #24096
+        elif is_peft_available() and isinstance(model, PeftModel):
+            # When training a model with PEFT & LoRA, assume the adapter has been saved properly.
+            if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
+                if os.path.exists(resume_from_checkpoint):
+                    model.load_adapter(resume_from_checkpoint, model.active_adapter)
+                else:
+                    logger.warning(
+                        "The intermediate checkpoints of PEFT may not be saved correctly, "
+                        f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+                        "Check some examples here: https://github.com/huggingface/peft/issues/96"
+                    )
+            else:
+                logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
         else:
             # We load the sharded checkpoint
             load_result = load_sharded_checkpoint(
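
With this branch in place, resuming a LoRA run reloads the adapter via `model.load_adapter` instead of going through `torch.load`. A hedged end-to-end sketch (model name, hyperparameters, and `train_dataset` are illustrative, not from this PR):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative base model
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16))

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", save_steps=100, max_steps=200),
    train_dataset=train_dataset,  # assumed to be defined elsewhere
)

# Checkpoints under out/checkpoint-* contain only adapter files; with this fix,
# resuming reloads them via model.load_adapter(checkpoint_dir, model.active_adapter).
trainer.train(resume_from_checkpoint=True)
```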
@@ -2102,8 +2126,8 @@ class Trainer:
                 else:
                     logger.warning(
                         "The intermediate checkpoints of PEFT may not be saved correctly, "
-                        f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
-                        "here are some examples https://github.com/huggingface/peft/issues/96"
+                        f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+                        "Check some examples here: https://github.com/huggingface/peft/issues/96"
                     )
                     has_been_loaded = False
             else: