Unverified Commit 12fbe3f7 authored by hlky's avatar hlky Committed by GitHub
Browse files

Use Pipelines without unet (#10440)



* Use Pipelines without unet

* unet.config.in_channels

* default_sample_size

* is_unet_version_less_0_9_0

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 83ba01a3
...@@ -416,10 +416,14 @@ class AdaptiveMaskInpaintPipeline( ...@@ -416,10 +416,14 @@ class AdaptiveMaskInpaintPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
...@@ -438,7 +442,7 @@ class AdaptiveMaskInpaintPipeline( ...@@ -438,7 +442,7 @@ class AdaptiveMaskInpaintPipeline(
unet._internal_dict = FrozenDict(new_config) unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet.config.in_channels != 9: if unet is not None and unet.config.in_channels != 9:
logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
self.register_modules( self.register_modules(
......
...@@ -132,10 +132,14 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin) ...@@ -132,10 +132,14 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -152,10 +152,14 @@ class InstaFlowPipeline( ...@@ -152,10 +152,14 @@ class InstaFlowPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -234,10 +234,14 @@ class IPAdapterFaceIDStableDiffusionPipeline( ...@@ -234,10 +234,14 @@ class IPAdapterFaceIDStableDiffusionPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -379,10 +379,14 @@ class LLMGroundedDiffusionPipeline( ...@@ -379,10 +379,14 @@ class LLMGroundedDiffusionPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -539,10 +539,14 @@ class StableDiffusionLongPromptWeightingPipeline( ...@@ -539,10 +539,14 @@ class StableDiffusionLongPromptWeightingPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -678,7 +678,11 @@ class SDXLLongPromptWeightingPipeline( ...@@ -678,7 +678,11 @@ class SDXLLongPromptWeightingPipeline(
self.mask_processor = VaeImageProcessor( self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
) )
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......
...@@ -3793,10 +3793,14 @@ class MatryoshkaPipeline( ...@@ -3793,10 +3793,14 @@ class MatryoshkaPipeline(
# new_config["clip_sample"] = False # new_config["clip_sample"] = False
# scheduler._internal_dict = FrozenDict(new_config) # scheduler._internal_dict = FrozenDict(new_config)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -168,7 +168,11 @@ class DemoFusionSDXLPipeline( ...@@ -168,7 +168,11 @@ class DemoFusionSDXLPipeline(
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......
...@@ -150,10 +150,14 @@ class FabricPipeline(DiffusionPipeline): ...@@ -150,10 +150,14 @@ class FabricPipeline(DiffusionPipeline):
): ):
super().__init__() super().__init__()
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -216,7 +216,11 @@ class KolorsDifferentialImg2ImgPipeline( ...@@ -216,7 +216,11 @@ class KolorsDifferentialImg2ImgPipeline(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
) )
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
# Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
def encode_prompt( def encode_prompt(
......
...@@ -174,10 +174,14 @@ class Prompt2PromptPipeline( ...@@ -174,10 +174,14 @@ class Prompt2PromptPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -494,7 +494,11 @@ class StyleAlignedSDXLPipeline( ...@@ -494,7 +494,11 @@ class StyleAlignedSDXLPipeline(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
) )
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......
...@@ -460,10 +460,14 @@ class StableDiffusionBoxDiffPipeline( ...@@ -460,10 +460,14 @@ class StableDiffusionBoxDiffPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -427,10 +427,14 @@ class StableDiffusionPAGPipeline( ...@@ -427,10 +427,14 @@ class StableDiffusionPAGPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -231,7 +231,11 @@ class StableDiffusionXLControlNetAdapterPipeline( ...@@ -231,7 +231,11 @@ class StableDiffusionXLControlNetAdapterPipeline(
self.control_image_processor = VaeImageProcessor( self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
) )
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt( def encode_prompt(
......
...@@ -379,7 +379,11 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline( ...@@ -379,7 +379,11 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
self.control_image_processor = VaeImageProcessor( self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
) )
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt( def encode_prompt(
......
...@@ -256,7 +256,11 @@ class StableDiffusionXLPipelineIpex( ...@@ -256,7 +256,11 @@ class StableDiffusionXLPipelineIpex(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......
...@@ -151,10 +151,14 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -151,10 +151,14 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -148,10 +148,14 @@ class StableDiffusionIPEXPipeline( ...@@ -148,10 +148,14 @@ class StableDiffusionIPEXPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment