"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "bdb3f50771e268b028af7e6d1fb3d9652a2821f7"
Unverified Commit 6daf7c31 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

Support PEFT models when saving the model using trainer (#24073)

* support PEFT models when saving the model using trainer

* fixup
parent 1e4a7737
......@@ -147,6 +147,7 @@ from .utils import (
is_datasets_available,
is_in_notebook,
is_ipex_available,
is_peft_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
......@@ -205,6 +206,10 @@ if is_safetensors_available():
import safetensors.torch
if is_peft_available():
from peft import PeftModel
skip_first_batches = None
if is_accelerate_available():
from accelerate import __version__ as accelerate_version
......@@ -2897,13 +2902,15 @@ class Trainer:
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(unwrap_model(self.model), PreTrainedModel):
if isinstance(unwrap_model(self.model), supported_classes):
unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment