Unverified Commit c9e3c0b4 authored by Shubham Krishna's avatar Shubham Krishna Committed by GitHub
Browse files

[`PEFT`] Fix `save_pretrained` to make sure adapters weights are also saved on TPU (#29388)

* Fix for saving ad
apter weights when using PEFT

* Change supported-classes to PushToHubMixin
parent b4b96251
......@@ -134,6 +134,7 @@ from .utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
PushInProgress,
PushToHubMixin,
can_return_loss,
find_labels,
is_accelerate_available,
......@@ -3019,9 +3020,10 @@ class Trainer:
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
supported_classes = (PushToHubMixin,)
xm.rendezvous("saving_checkpoint")
if not isinstance(model, PreTrainedModel):
if isinstance(unwrap_model(model), PreTrainedModel):
if not isinstance(model, supported_classes):
if isinstance(unwrap_model(model), supported_classes):
unwrap_model(model).save_pretrained(
output_dir,
is_main_process=self.args.should_save,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment