Unverified Commit df2b548e authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Make safety_checker optional in more pipelines (#1796)



* Make safety_checker optional in more pipelines.

* Remove inappropriate comment in inpaint pipeline.

* InPaint Test: set feature_extractor to None.

* Remove import

* img2img test: set feature_extractor to None.

* inpaint sd2 test: set feature_extractor to None.
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent b6d47023
...@@ -92,7 +92,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -92,7 +92,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker"] _optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
......
...@@ -161,6 +161,8 @@ class PaintByExamplePipeline(DiffusionPipeline): ...@@ -161,6 +161,8 @@ class PaintByExamplePipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
# TODO: feature_extractor is required to encode initial images (if they are in PIL format),
# we should give a descriptive message if the pipeline doesn't have one.
_optional_components = ["safety_checker"] _optional_components = ["safety_checker"]
def __init__( def __init__(
......
...@@ -65,6 +65,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): ...@@ -65,6 +65,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
# TODO: feature_extractor is required to encode images (if they are in PIL format),
# we should give a descriptive message if the pipeline doesn't have one.
_optional_components = ["safety_checker"] _optional_components = ["safety_checker"]
def __init__( def __init__(
......
...@@ -90,7 +90,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -90,7 +90,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker"] _optional_components = ["safety_checker", "feature_extractor"]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__( def __init__(
......
...@@ -166,7 +166,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -166,7 +166,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]): feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`. Model that extracts features from generated images to be used as inputs for the `safety_checker`.
""" """
_optional_components = ["safety_checker"] _optional_components = ["safety_checker", "feature_extractor"]
def __init__( def __init__(
self, self,
......
...@@ -31,7 +31,7 @@ from diffusers import ( ...@@ -31,7 +31,7 @@ from diffusers import (
) )
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin from ...test_pipelines_common import PipelineTesterMixin
...@@ -78,7 +78,6 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -78,7 +78,6 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
) )
text_encoder = CLIPTextModel(text_encoder_config) text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
components = { components = {
"unet": unet, "unet": unet,
...@@ -87,7 +86,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -87,7 +86,7 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
"text_encoder": text_encoder, "text_encoder": text_encoder,
"tokenizer": tokenizer, "tokenizer": tokenizer,
"safety_checker": None, "safety_checker": None,
"feature_extractor": feature_extractor, "feature_extractor": None,
} }
return components return components
......
...@@ -32,7 +32,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint impo ...@@ -32,7 +32,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint impo
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu from diffusers.utils.testing_utils import require_torch_gpu
from PIL import Image from PIL import Image
from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin from ...test_pipelines_common import PipelineTesterMixin
...@@ -79,7 +79,6 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -79,7 +79,6 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
) )
text_encoder = CLIPTextModel(text_encoder_config) text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
components = { components = {
"unet": unet, "unet": unet,
...@@ -88,7 +87,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test ...@@ -88,7 +87,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
"text_encoder": text_encoder, "text_encoder": text_encoder,
"tokenizer": tokenizer, "tokenizer": tokenizer,
"safety_checker": None, "safety_checker": None,
"feature_extractor": feature_extractor, "feature_extractor": None,
} }
return components return components
......
...@@ -24,7 +24,7 @@ from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeli ...@@ -24,7 +24,7 @@ from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeli
from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, slow from diffusers.utils.testing_utils import require_torch_gpu, slow
from PIL import Image from PIL import Image
from transformers import CLIPImageProcessor, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin from ...test_pipelines_common import PipelineTesterMixin
...@@ -78,7 +78,6 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.Tes ...@@ -78,7 +78,6 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.Tes
) )
text_encoder = CLIPTextModel(text_encoder_config) text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
components = { components = {
"unet": unet, "unet": unet,
...@@ -87,7 +86,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.Tes ...@@ -87,7 +86,7 @@ class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.Tes
"text_encoder": text_encoder, "text_encoder": text_encoder,
"tokenizer": tokenizer, "tokenizer": tokenizer,
"safety_checker": None, "safety_checker": None,
"feature_extractor": feature_extractor, "feature_extractor": None,
} }
return components return components
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment