Unverified Commit 2be8a909 authored by raghavanone's avatar raghavanone Committed by GitHub
Browse files

Save image_processor while saving pipeline (ImageSegmentationPipeline) (#25884)

* Save image_processor while saving pipeline (ImageSegmentationPipeline)

* Fix black issues
parent a39ebbf8
...@@ -872,6 +872,9 @@ class Pipeline(_ScikitCompat): ...@@ -872,6 +872,9 @@ class Pipeline(_ScikitCompat):
if self.feature_extractor is not None: if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory) self.feature_extractor.save_pretrained(save_directory)
if self.image_processor is not None:
self.image_processor.save_pretrained(save_directory)
if self.modelcard is not None: if self.modelcard is not None:
self.modelcard.save_pretrained(save_directory) self.modelcard.save_pretrained(save_directory)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import hashlib import hashlib
import tempfile
import unittest import unittest
from typing import Dict from typing import Dict
...@@ -714,3 +715,17 @@ class ImageSegmentationPipelineTests(unittest.TestCase): ...@@ -714,3 +715,17 @@ class ImageSegmentationPipelineTests(unittest.TestCase):
}, },
], ],
) )
def test_save_load(self):
model_id = "hf-internal-testing/tiny-detr-mobilenetsv3-panoptic"
model = AutoModelForImageSegmentation.from_pretrained(model_id)
image_processor = AutoImageProcessor.from_pretrained(model_id)
image_segmenter = pipeline(
task="image-segmentation",
model=model,
image_processor=image_processor,
)
with tempfile.TemporaryDirectory() as tmpdirname:
image_segmenter.save_pretrained(tmpdirname)
pipeline(task="image-segmentation", model=tmpdirname)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment