Unverified Commit 12fbe3f7 authored by hlky's avatar hlky Committed by GitHub
Browse files

Use Pipelines without unet (#10440)



* Use Pipelines without unet

* unet.config.in_channels

* default_sample_size

* is_unet_version_less_0_9_0

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 83ba01a3
...@@ -181,10 +181,14 @@ class StableDiffusionReferencePipeline( ...@@ -181,10 +181,14 @@ class StableDiffusionReferencePipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
...@@ -202,7 +206,7 @@ class StableDiffusionReferencePipeline( ...@@ -202,7 +206,7 @@ class StableDiffusionReferencePipeline(
new_config["sample_size"] = 64 new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config) unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet.config.in_channels != 4: if unet is not None and unet.config.in_channels != 4:
logger.warning( logger.warning(
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`," f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
......
...@@ -236,10 +236,14 @@ class StableDiffusionRepaintPipeline( ...@@ -236,10 +236,14 @@ class StableDiffusionRepaintPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
...@@ -257,7 +261,7 @@ class StableDiffusionRepaintPipeline( ...@@ -257,7 +261,7 @@ class StableDiffusionRepaintPipeline(
new_config["sample_size"] = 64 new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config) unet._internal_dict = FrozenDict(new_config)
# Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
if unet.config.in_channels != 4: if unet is not None and unet.config.in_channels != 4:
logger.warning( logger.warning(
f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default,"
f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`," f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`,"
......
...@@ -753,10 +753,14 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -753,10 +753,14 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -757,10 +757,14 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -757,10 +757,14 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -669,10 +669,14 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline): ...@@ -669,10 +669,14 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -319,7 +319,11 @@ class AnimateDiffSDXLPipeline( ...@@ -319,7 +319,11 @@ class AnimateDiffSDXLPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt
def encode_prompt( def encode_prompt(
......
...@@ -184,7 +184,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoa ...@@ -184,7 +184,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoa
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
if unet.config.in_channels != 6: if unet is not None and unet.config.in_channels != 6:
logger.warning( logger.warning(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
) )
......
...@@ -186,7 +186,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLora ...@@ -186,7 +186,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLora
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
if unet.config.in_channels != 6: if unet is not None and unet.config.in_channels != 6:
logger.warning( logger.warning(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
) )
......
...@@ -142,7 +142,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi ...@@ -142,7 +142,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
if unet.config.in_channels != 6: if unet is not None and unet.config.in_channels != 6:
logger.warning( logger.warning(
"It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`."
) )
......
...@@ -253,10 +253,14 @@ class AltDiffusionPipeline( ...@@ -253,10 +253,14 @@ class AltDiffusionPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -281,10 +281,14 @@ class AltDiffusionImg2ImgPipeline( ...@@ -281,10 +281,14 @@ class AltDiffusionImg2ImgPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -213,10 +213,14 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Sta ...@@ -213,10 +213,14 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Sta
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -183,10 +183,14 @@ class StableDiffusionInpaintPipelineLegacy( ...@@ -183,10 +183,14 @@ class StableDiffusionInpaintPipelineLegacy(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -191,7 +191,11 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL ...@@ -191,7 +191,11 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
def encode_prompt( def encode_prompt(
self, self,
......
...@@ -210,7 +210,11 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu ...@@ -210,7 +210,11 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
# Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
def encode_prompt( def encode_prompt(
......
...@@ -368,10 +368,14 @@ class LEditsPPPipelineStableDiffusion( ...@@ -368,10 +368,14 @@ class LEditsPPPipelineStableDiffusion(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -384,7 +384,11 @@ class LEditsPPPipelineStableDiffusionXL( ...@@ -384,7 +384,11 @@ class LEditsPPPipelineStableDiffusionXL(
"The scheduler has been changed to DPMSolverMultistepScheduler." "The scheduler has been changed to DPMSolverMultistepScheduler."
) )
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
......
...@@ -205,7 +205,11 @@ class KolorsPAGPipeline( ...@@ -205,7 +205,11 @@ class KolorsPAGPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = (
self.unet.config.sample_size
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
else 128
)
self.set_pag_applied_layers(pag_applied_layers) self.set_pag_applied_layers(pag_applied_layers)
......
...@@ -259,10 +259,14 @@ class StableDiffusionPAGPipeline( ...@@ -259,10 +259,14 @@ class StableDiffusionPAGPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
...@@ -254,10 +254,14 @@ class StableDiffusionPAGImg2ImgPipeline( ...@@ -254,10 +254,14 @@ class StableDiffusionPAGImg2ImgPipeline(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
) )
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( is_unet_version_less_0_9_0 = (
version.parse(unet.config._diffusers_version).base_version unet is not None
) < version.parse("0.9.0.dev0") and hasattr(unet.config, "_diffusers_version")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
)
is_unet_sample_size_less_64 = (
unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = ( deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than" "The configuration file of the unet has set the default `sample_size` to smaller than"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment