Unverified Commit 4a38166a authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Allow saving `None` pipeline components (#1118)

* Allow saving `None` pipeline components

* support flax as well

* style
parent 0edf9ca0
......@@ -161,6 +161,10 @@ class FlaxDiffusionPipeline(ConfigMixin):
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__
save_method_name = None
......@@ -367,6 +371,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True
......
......@@ -176,6 +176,10 @@ class DiffusionPipeline(ConfigMixin):
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__
save_method_name = None
......@@ -477,6 +481,11 @@ class DiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
......
......@@ -15,6 +15,7 @@
import gc
import random
import tempfile
import time
import unittest
......@@ -318,6 +319,16 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
# check that there's no error when saving a pipeline with one of the models being None
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
# sanity check that the pipeline still works
assert pipe.safety_checker is None
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
def test_stable_diffusion_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment