Unverified Commit ee7e141d authored by hlky's avatar hlky Committed by GitHub
Browse files

Use pipelines without vae (#10441)



* Use pipelines without vae

* getattr

* vqvae

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 01bd7964
...@@ -219,7 +219,7 @@ class StableDiffusionReferencePipeline( ...@@ -219,7 +219,7 @@ class StableDiffusionReferencePipeline(
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
......
...@@ -274,7 +274,7 @@ class StableDiffusionRepaintPipeline( ...@@ -274,7 +274,7 @@ class StableDiffusionRepaintPipeline(
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
......
...@@ -806,7 +806,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline): ...@@ -806,7 +806,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
self.engine = {} # loaded in build_engines() self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
......
...@@ -810,7 +810,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline): ...@@ -810,7 +810,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
self.engine = {} # loaded in build_engines() self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
......
...@@ -722,7 +722,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline): ...@@ -722,7 +722,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
self.engine = {} # loaded in build_engines() self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode self.vae.forward = self.vae.decode
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(requires_safety_checker=requires_safety_checker)
......
...@@ -310,7 +310,7 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline): ...@@ -310,7 +310,7 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline):
controlnet=controlnet, controlnet=controlnet,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
......
...@@ -233,7 +233,7 @@ class PromptDiffusionPipeline( ...@@ -233,7 +233,7 @@ class PromptDiffusionPipeline(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
image_encoder=image_encoder, image_encoder=image_encoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor( self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
......
...@@ -78,7 +78,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -78,7 +78,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
# Copy from statement here and all the methods we take from stable_diffusion_pipeline # Copy from statement here and all the methods we take from stable_diffusion_pipeline
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.retriever = retriever self.retriever = retriever
......
...@@ -194,10 +194,10 @@ class AllegroPipeline(DiffusionPipeline): ...@@ -194,10 +194,10 @@ class AllegroPipeline(DiffusionPipeline):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
) )
self.vae_scale_factor_spatial = ( self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
) )
self.vae_scale_factor_temporal = ( self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
) )
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
......
...@@ -66,7 +66,9 @@ class AmusedPipeline(DiffusionPipeline): ...@@ -66,7 +66,9 @@ class AmusedPipeline(DiffusionPipeline):
transformer=transformer, transformer=transformer,
scheduler=scheduler, scheduler=scheduler,
) )
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) self.vae_scale_factor = (
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad() @torch.no_grad()
......
...@@ -81,7 +81,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline): ...@@ -81,7 +81,9 @@ class AmusedImg2ImgPipeline(DiffusionPipeline):
transformer=transformer, transformer=transformer,
scheduler=scheduler, scheduler=scheduler,
) )
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) self.vae_scale_factor = (
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad() @torch.no_grad()
......
...@@ -89,7 +89,9 @@ class AmusedInpaintPipeline(DiffusionPipeline): ...@@ -89,7 +89,9 @@ class AmusedInpaintPipeline(DiffusionPipeline):
transformer=transformer, transformer=transformer,
scheduler=scheduler, scheduler=scheduler,
) )
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) self.vae_scale_factor = (
2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
self.mask_processor = VaeImageProcessor( self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_scale_factor=self.vae_scale_factor,
......
...@@ -139,7 +139,7 @@ class AnimateDiffPipeline( ...@@ -139,7 +139,7 @@ class AnimateDiffPipeline(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
image_encoder=image_encoder, image_encoder=image_encoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
......
...@@ -180,7 +180,7 @@ class AnimateDiffControlNetPipeline( ...@@ -180,7 +180,7 @@ class AnimateDiffControlNetPipeline(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
image_encoder=image_encoder, image_encoder=image_encoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_video_processor = VideoProcessor( self.control_video_processor = VideoProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
......
...@@ -307,7 +307,7 @@ class AnimateDiffSDXLPipeline( ...@@ -307,7 +307,7 @@ class AnimateDiffSDXLPipeline(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size self.default_sample_size = self.unet.config.sample_size
......
...@@ -188,7 +188,7 @@ class AnimateDiffSparseControlNetPipeline( ...@@ -188,7 +188,7 @@ class AnimateDiffSparseControlNetPipeline(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
image_encoder=image_encoder, image_encoder=image_encoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor( self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
......
...@@ -243,7 +243,7 @@ class AnimateDiffVideoToVideoPipeline( ...@@ -243,7 +243,7 @@ class AnimateDiffVideoToVideoPipeline(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
image_encoder=image_encoder, image_encoder=image_encoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
def encode_prompt( def encode_prompt(
......
...@@ -270,7 +270,7 @@ class AnimateDiffVideoToVideoControlNetPipeline( ...@@ -270,7 +270,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
image_encoder=image_encoder, image_encoder=image_encoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_video_processor = VideoProcessor( self.control_video_processor = VideoProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
......
...@@ -94,7 +94,7 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -94,7 +94,7 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
scheduler=scheduler, scheduler=scheduler,
vocoder=vocoder, vocoder=vocoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def _encode_prompt( def _encode_prompt(
self, self,
......
...@@ -207,7 +207,7 @@ class AudioLDM2Pipeline(DiffusionPipeline): ...@@ -207,7 +207,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
scheduler=scheduler, scheduler=scheduler,
vocoder=vocoder, vocoder=vocoder,
) )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
def enable_vae_slicing(self): def enable_vae_slicing(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment