Unverified Commit 12fbe3f7 authored by hlky's avatar hlky Committed by GitHub
Browse files

Use Pipelines without unet (#10440)



* Use Pipelines without unet

* unet.config.in_channels

* default_sample_size

* is_unet_version_less_0_9_0

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 83ba01a3
......@@ -286,10 +286,14 @@ class StableDiffusionPAGInpaintPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
......
......@@ -278,7 +278,11 @@ class StableDiffusionXLPAGPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......
......@@ -132,10 +132,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
......
......@@ -159,10 +159,14 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
......
......@@ -254,12 +254,15 @@ class StableDiffusionPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
self._is_unet_config_sample_size_int = unet is not None and isinstance(unet.config.sample_size, int)
is_unet_sample_size_less_64 = (
hasattr(unet.config, "sample_size")
unet is not None
and hasattr(unet.config, "sample_size")
and self._is_unet_config_sample_size_int
and unet.config.sample_size < 64
)
......
......@@ -130,10 +130,14 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
):
super().__init__()
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
......
......@@ -104,10 +104,14 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMi
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
......
......@@ -282,10 +282,14 @@ class StableDiffusionImg2ImgPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
......
......@@ -229,10 +229,14 @@ class StableDiffusionInpaintPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
......@@ -251,7 +255,7 @@ class StableDiffusionInpaintPipeline(
unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet.config.in_channels != 9:
if unet is not None and unet.config.in_channels != 9:
logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
self.register_modules(
......
......@@ -344,10 +344,14 @@ class StableDiffusionDiffEditPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
......
......@@ -173,7 +173,11 @@ class StableDiffusionXLKDiffusionPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
model = ModelWrapper(unet, scheduler.alphas_cumprod)
if scheduler.config.prediction_type == "v_prediction":
......
......@@ -124,10 +124,14 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
is_unet_version_less_0_9_0 = (
unet is not None
and hasattr(unet.config, "_diffusers_version")
and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
......
......@@ -272,7 +272,11 @@ class StableDiffusionXLPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......
......@@ -201,7 +201,11 @@ class StableDiffusionXLInstructPix2PixPipeline(
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
self.is_cosxl_edit = is_cosxl_edit
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......
......@@ -304,7 +304,11 @@ class StableDiffusionXLAdapterPipeline(
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
......
......@@ -422,7 +422,11 @@ class TextToVideoZeroSDXLPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment