Unverified Commit 7824fa43 authored by Yessen Kanapin's avatar Yessen Kanapin Committed by GitHub
Browse files

expose safe_serialization argument in the pipeline API (#23775)



expose safe_serialization argument of PreTrainedModel and TFPreTrainedModel in the save_pretrained of the pipeline api
Co-authored-by: default avatarYessen Kanapin <yessen@deepinfra.com>
parent b4919cb5
......@@ -821,13 +821,15 @@ class Pipeline(_ScikitCompat):
# then we should keep working
self.image_processor = self.feature_extractor
def save_pretrained(self, save_directory: str):
def save_pretrained(self, save_directory: str, safe_serialization: bool = False):
"""
Save the pipeline's model and tokenizer.
Args:
save_directory (`str`):
A path to the directory where to saved. It will be created if it doesn't exist.
safe_serialization (`str`):
Whether to save the model using `safetensors` or the traditional way for PyTorch or Tensorflow
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
......@@ -855,7 +857,7 @@ class Pipeline(_ScikitCompat):
# Save the pipeline custom code
custom_object_save(self, save_directory)
self.model.save_pretrained(save_directory)
self.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(save_directory)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment