Unverified Commit c9e3c0b4 authored by Shubham Krishna, committed by GitHub

[`PEFT`] Fix `save_pretrained` to make sure adapters weights are also saved on TPU (#29388)

* Fix for saving adapter weights when using PEFT

* Change supported-classes to PushToHubMixin
parent b4b96251
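
Before this change, the TPU saving path only recognized `PreTrainedModel`, so a PEFT model fell through to the raw state-dict path and its adapter files were never written via `save_pretrained`. A minimal sketch of the class relationships the fix relies on (assuming `peft` is installed; that `PeftModel` inherits transformers' `PushToHubMixin` is implied by this fix, not shown in the diff):

```python
# Sketch: why checking against PushToHubMixin covers PEFT models.
# Assumes transformers and peft are installed.
from peft import PeftModel
from transformers import PreTrainedModel
from transformers.utils import PushToHubMixin

print(issubclass(PreTrainedModel, PushToHubMixin))  # True: regular models still pass the check
print(issubclass(PeftModel, PreTrainedModel))       # False: the old check skipped PEFT models
print(issubclass(PeftModel, PushToHubMixin))        # True: the new check includes them
```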
@@ -134,6 +134,7 @@ from .utils import (
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     PushInProgress,
+    PushToHubMixin,
     can_return_loss,
     find_labels,
     is_accelerate_available,
@@ -3019,9 +3020,10 @@ class Trainer:
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
+        supported_classes = (PushToHubMixin,)
         xm.rendezvous("saving_checkpoint")
-        if not isinstance(model, PreTrainedModel):
-            if isinstance(unwrap_model(model), PreTrainedModel):
+        if not isinstance(model, supported_classes):
+            if isinstance(unwrap_model(model), supported_classes):
                 unwrap_model(model).save_pretrained(
                     output_dir,
                     is_main_process=self.args.should_save,
...
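
Condensed, the saving path now dispatches as in the sketch below. `save_on_tpu` is a hypothetical stand-in for `Trainer._save_tpu`, `xm` is `torch_xla`'s `xla_model`, and the raw state-dict fallback branch is an assumption based on the surrounding code rather than part of the diff shown above:

```python
# Hedged sketch of the TPU saving path after this change.
import os

import torch_xla.core.xla_model as xm
from transformers.modeling_utils import unwrap_model
from transformers.utils import WEIGHTS_NAME, PushToHubMixin

supported_classes = (PushToHubMixin,)  # previously (PreTrainedModel,)


def save_on_tpu(model, output_dir, should_save=True):
    xm.rendezvous("saving_checkpoint")  # sync TPU workers before writing to disk
    if not isinstance(model, supported_classes):
        if isinstance(unwrap_model(model), supported_classes):
            # A wrapped PeftModel now passes this check, so save_pretrained
            # runs and the adapter weights/config land in output_dir.
            unwrap_model(model).save_pretrained(
                output_dir,
                is_main_process=should_save,
                state_dict=model.state_dict(),
                save_function=xm.save,
            )
        else:
            # Unsupported model type: fall back to a raw state-dict dump,
            # which does not produce PEFT adapter files.
            xm.save(model.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
    else:
        model.save_pretrained(output_dir, is_main_process=should_save, save_function=xm.save)
```

With `supported_classes = (PushToHubMixin,)`, a `PeftModel` takes the `save_pretrained` branch, so its adapter weights are written instead of being skipped.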