"test/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "7de5f98c18bcf358178e8cae090b82ae3f76a338"
Unverified Commit 6daf7c31 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

Support PEFT models when saving the model using trainer (#24073)

* support PEFT models when saving the model using trainer

* fixup
parent 1e4a7737
...@@ -147,6 +147,7 @@ from .utils import ( ...@@ -147,6 +147,7 @@ from .utils import (
is_datasets_available, is_datasets_available,
is_in_notebook, is_in_notebook,
is_ipex_available, is_ipex_available,
is_peft_available,
is_safetensors_available, is_safetensors_available,
is_sagemaker_dp_enabled, is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled, is_sagemaker_mp_enabled,
...@@ -205,6 +206,10 @@ if is_safetensors_available(): ...@@ -205,6 +206,10 @@ if is_safetensors_available():
import safetensors.torch import safetensors.torch
if is_peft_available():
from peft import PeftModel
skip_first_batches = None skip_first_batches = None
if is_accelerate_available(): if is_accelerate_available():
from accelerate import __version__ as accelerate_version from accelerate import __version__ as accelerate_version
...@@ -2897,13 +2902,15 @@ class Trainer: ...@@ -2897,13 +2902,15 @@ class Trainer:
output_dir = output_dir if output_dir is not None else self.args.output_dir output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}") logger.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
# Save a trained model and configuration using `save_pretrained()`. # Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel): if not isinstance(self.model, supported_classes):
if state_dict is None: if state_dict is None:
state_dict = self.model.state_dict() state_dict = self.model.state_dict()
if isinstance(unwrap_model(self.model), PreTrainedModel): if isinstance(unwrap_model(self.model), supported_classes):
unwrap_model(self.model).save_pretrained( unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment