Unverified Commit 88bdd97c authored by Aryan V S, committed by GitHub

IP adapter support for most pipelines (#5900)



* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

* update tests

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py

* support ip-adapter in src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

* support ip-adapter in src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py

* revert changes to sd_attend_and_excite and sd_upscale

* make style

* fix broken tests

* update ip-adapter implementation to latest

* apply suggestions from review

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 08b453e3
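
For context (not part of the diff below): every updated pipeline exposes the same two-step workflow — call `load_ip_adapter(...)` from the new `IPAdapterMixin`, then pass `ip_adapter_image` to `__call__`. A minimal sketch with the LCM text-to-image pipeline; the checkpoint, adapter weights, and URL below are illustrative placeholders, not something this commit pins down:

```python
import torch
from diffusers import LatentConsistencyModelPipeline
from diffusers.utils import load_image

# Illustrative checkpoint; any SD-1.5-based LCM checkpoint should work the same way.
pipe = LatentConsistencyModelPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float16
).to("cuda")

# load_ip_adapter also loads the CLIP image encoder from the adapter repo when the
# pipeline was created without one (image_encoder is an optional component).
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ip_reference = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(
    prompt="a cozy cabin in the woods, best quality",
    ip_adapter_image=ip_reference,  # new argument added in this PR
    num_inference_steps=4,
    guidance_scale=8.0,
).images[0]
```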
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -20,11 +20,11 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import PIL.Image
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import LCMScheduler
 from ...utils import (
@@ -129,7 +129,7 @@ EXAMPLE_DOC_STRING = """
 class LatentConsistencyModelImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for image-to-image generation using a latent consistency model.
@@ -142,6 +142,7 @@ class LatentConsistencyModelImg2ImgPipeline(
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
@@ -166,7 +167,7 @@ class LatentConsistencyModelImg2ImgPipeline(
     """
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
@@ -179,6 +180,7 @@ class LatentConsistencyModelImg2ImgPipeline(
         scheduler: LCMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -191,6 +193,7 @@ class LatentConsistencyModelImg2ImgPipeline(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         if safety_checker is None and requires_safety_checker:
@@ -449,6 +452,31 @@ class LatentConsistencyModelImg2ImgPipeline(
         return prompt_embeds, negative_prompt_embeds
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -647,6 +675,7 @@ class LatentConsistencyModelImg2ImgPipeline(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -695,6 +724,8 @@ class LatentConsistencyModelImg2ImgPipeline(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -758,6 +789,12 @@ class LatentConsistencyModelImg2ImgPipeline(
         device = self._execution_device
         # do_classifier_free_guidance = guidance_scale > 1.0
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
         # 3. Encode input prompt
         lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -815,6 +852,9 @@ class LatentConsistencyModelImg2ImgPipeline(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
         # 8. LCM Multistep Sampling Loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -829,6 +869,7 @@ class LatentConsistencyModelImg2ImgPipeline(
                     timestep_cond=w_embedding,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
......
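
A note on the `output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True` check repeated in each pipeline: when the loaded IP-Adapter uses the plain `ImageProjection` layer, `encode_image` returns the pooled CLIP image embedding; for other projection modules (for example, the resampler used by the IP-Adapter Plus variants) it returns the penultimate hidden states instead. This mirrors the logic in `StableDiffusionPipeline.encode_image`, from which the method is copied verbatim.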
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -19,11 +19,11 @@ import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import LCMScheduler
 from ...utils import (
@@ -107,7 +107,7 @@ def retrieve_timesteps(
 class LatentConsistencyModelPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image generation using a latent consistency model.
@@ -120,6 +120,7 @@ class LatentConsistencyModelPipeline(
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
@@ -144,7 +145,7 @@ class LatentConsistencyModelPipeline(
     """
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "denoised", "prompt_embeds", "w_embedding"]
@@ -157,6 +158,7 @@ class LatentConsistencyModelPipeline(
         scheduler: LCMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -185,6 +187,7 @@ class LatentConsistencyModelPipeline(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -433,6 +436,31 @@ class LatentConsistencyModelPipeline(
         return prompt_embeds, negative_prompt_embeds
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -581,6 +609,7 @@ class LatentConsistencyModelPipeline(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -629,6 +658,8 @@ class LatentConsistencyModelPipeline(
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -697,6 +728,12 @@ class LatentConsistencyModelPipeline(
         device = self._execution_device
         # do_classifier_free_guidance = guidance_scale > 1.0
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
         # 3. Encode input prompt
         lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -748,6 +785,9 @@ class LatentConsistencyModelPipeline(
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
        # 8. LCM MultiStep Sampling Loop:
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -762,6 +802,7 @@ class LatentConsistencyModelPipeline(
                     timestep_cond=w_embedding,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
......
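
A small but intentional asymmetry in the two LCM pipelines above: `negative_image_embeds` is computed but never concatenated into `image_embeds`, and `added_cond_kwargs` carries only the positive embeddings. Latent consistency models bake guidance into the distilled model through the `w_embedding` timestep conditioning rather than running classifier-free guidance at inference (note the commented-out `do_classifier_free_guidance` line), so there is no unconditional branch to feed.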
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -18,11 +18,11 @@ from typing import Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL.Image
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging
 from ...utils.torch_utils import randn_tensor
@@ -72,7 +72,9 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
-class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionInstructPix2PixPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin
+):
     r"""
     Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
@@ -83,6 +85,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
@@ -105,7 +108,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
     """
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"]
@@ -118,6 +121,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -146,6 +150,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -166,6 +171,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -213,6 +219,8 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -293,6 +301,16 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         self._guidance_scale = guidance_scale
         self._image_guidance_scale = image_guidance_scale
+        device = self._execution_device
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])
         if image is None:
             raise ValueError("`image` input cannot be undefined.")
@@ -367,6 +385,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+        # 8.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
         # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -383,7 +404,11 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
                 # predict the noise residual
                 noise_pred = self.unet(
-                    scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False
+                    scaled_latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
                 )[0]
                 # Hack:
@@ -603,6 +628,31 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
         return prompt_embeds
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
......
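
One wrinkle worth noting in the InstructPix2Pix changes above: because this pipeline runs three conditioning branches for its combined text and image guidance, the IP-Adapter embeddings are tiled as `[image_embeds, negative_image_embeds, negative_image_embeds]` rather than the usual two-way concatenation. Usage otherwise follows the same pattern; a sketch (the base checkpoint, adapter weights, and URLs are illustrative placeholders):

```python
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

source = load_image("https://example.com/room.png")      # image to edit (placeholder URL)
style_ref = load_image("https://example.com/style.png")  # IP-Adapter reference (placeholder URL)

edited = pipe(
    prompt="make it look like winter",
    image=source,                # the pix2pix conditioning image
    ip_adapter_image=style_ref,  # new argument added in this PR
    num_inference_steps=20,
    image_guidance_scale=1.5,
).images[0]
```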
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
@@ -19,11 +19,11 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL.Image
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
-from ...image_processor import VaeImageProcessorLDM3D
+from ...image_processor import PipelineImageInput, VaeImageProcessorLDM3D
-from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -82,7 +82,7 @@ class LDM3DPipelineOutput(BaseOutput):
 class StableDiffusionLDM3DPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
     r"""
     Pipeline for text-to-image and 3D generation using LDM3D.
@@ -95,6 +95,7 @@ class StableDiffusionLDM3DPipeline(
         - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
@@ -117,7 +118,7 @@ class StableDiffusionLDM3DPipeline(
     """
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     def __init__(
@@ -129,6 +130,7 @@ class StableDiffusionLDM3DPipeline(
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection],
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -157,6 +159,7 @@ class StableDiffusionLDM3DPipeline(
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor)
@@ -410,6 +413,31 @@ class StableDiffusionLDM3DPipeline(
         return prompt_embeds, negative_prompt_embeds
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
             has_nsfw_concept = None
@@ -529,6 +557,7 @@ class StableDiffusionLDM3DPipeline(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -573,6 +602,8 @@ class StableDiffusionLDM3DPipeline(
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -622,6 +653,14 @@ class StableDiffusionLDM3DPipeline(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
@@ -659,6 +698,9 @@ class StableDiffusionLDM3DPipeline(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+        # 6.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -673,6 +715,7 @@ class StableDiffusionLDM3DPipeline(
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
......
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -16,11 +16,11 @@ import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import DDIMScheduler
 from ...utils import (
@@ -59,13 +59,19 @@ EXAMPLE_DOC_STRING = """
 """
-class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin):
     r"""
     Pipeline for text-to-image generation using MultiDiffusion.
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -87,7 +93,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
     """
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     def __init__(
@@ -99,6 +105,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         scheduler: DDIMScheduler,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -127,6 +134,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -363,6 +371,31 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         return prompt_embeds, negative_prompt_embeds
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -529,6 +562,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -578,6 +612,8 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -632,6 +668,14 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
         # 3. Encode input prompt
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
@@ -681,6 +725,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+        # 7.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
         # 8. Denoising loop
         # Each denoising step also includes refinement of the latents with respect to the
         # views.
@@ -743,6 +790,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
                         t,
                         encoder_hidden_states=prompt_embeds_input,
                         cross_attention_kwargs=cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
                     ).sample
                     # perform guidance
......
...@@ -17,11 +17,11 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -17,11 +17,11 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import VaeImageProcessor from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -98,13 +98,17 @@ class CrossAttnStoreProcessor:
 # Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input
-class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -126,7 +130,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
     """
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     _exclude_from_cpu_offload = ["safety_checker"]
     def __init__(
@@ -138,6 +142,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -150,6 +155,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -386,6 +392,31 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         return prompt_embeds, negative_prompt_embeds
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
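The `encode_image` helper added above serves both IP-Adapter flavours: when the UNet's `encoder_hid_proj` is a plain `ImageProjection`, the pooled `image_embeds` are used; otherwise (the "plus" adapters) the penultimate hidden states of the CLIP vision encoder are used. A minimal standalone sketch of that difference; the `h94/IP-Adapter` image-encoder checkpoint is an assumption for illustration, not part of this diff:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Assumed checkpoint layout; any CLIP vision model with a projection head behaves the same way.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder"
)
feature_extractor = CLIPImageProcessor()

reference = Image.new("RGB", (224, 224))  # placeholder reference image
pixel_values = feature_extractor(reference, return_tensors="pt").pixel_values

with torch.no_grad():
    # plain IP-Adapter (ImageProjection): pooled embeds, shape (1, projection_dim)
    pooled = image_encoder(pixel_values).image_embeds
    # IP-Adapter "plus" variants: penultimate hidden states, shape (1, num_patches + 1, hidden_dim)
    hidden = image_encoder(pixel_values, output_hidden_states=True).hidden_states[-2]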
@@ -519,6 +550,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -565,6 +597,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -618,6 +652,14 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         # `sag_scale = 0` means no self-attention guidance
         do_self_attention_guidance = sag_scale > 0.0
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
@@ -655,6 +697,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+        # 6.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+        added_uncond_kwargs = {"image_embeds": negative_image_embeds} if ip_adapter_image is not None else None
         # 7. Denoising loop
         store_processor = CrossAttnStoreProcessor()
         self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
@@ -680,6 +726,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
                 ).sample
                 # perform guidance
@@ -703,7 +750,12 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
                         )
                         uncond_emb, _ = prompt_embeds.chunk(2)
                         # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                        degraded_pred = self.unet(
+                            degraded_latents,
+                            t,
+                            encoder_hidden_states=uncond_emb,
+                            added_cond_kwargs=added_uncond_kwargs,
+                        ).sample
                         noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
                     else:
                         # DDIM-like prediction of x0
@@ -715,7 +767,12 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
                             pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t)
                         )
                         # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                        degraded_pred = self.unet(
+                            degraded_latents,
+                            t,
+                            encoder_hidden_states=prompt_embeds,
+                            added_cond_kwargs=added_cond_kwargs,
+                        ).sample
                         noise_pred += sag_scale * (noise_pred - degraded_pred)
                 # compute the previous noisy sample x_t -> x_t-1
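With `IPAdapterMixin` in the SAG pipeline's bases and `ip_adapter_image` threaded through to the UNet, usage mirrors the other IP-Adapter-enabled pipelines. A rough sketch; the Stable Diffusion and IP-Adapter checkpoint names below are assumptions for illustration, not part of this change:

import torch
from diffusers import StableDiffusionSAGPipeline
from diffusers.utils import load_image

pipe = StableDiffusionSAGPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# load_ip_adapter typically also loads the matching image encoder when the pipeline has none
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

reference = load_image("https://example.com/reference.png")  # hypothetical image URL
image = pipe(
    prompt="a fantasy landscape, intricate detail",
    ip_adapter_image=reference,
    sag_scale=0.75,
    num_inference_steps=50,
).images[0]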
@@ -5,10 +5,12 @@ from typing import Callable, List, Optional, Union
 import numpy as np
 import torch
 from packaging import version
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 from ...configuration_utils import FrozenDict
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...image_processor import PipelineImageInput
+from ...loaders import IPAdapterMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, logging
 from ...utils.torch_utils import randn_tensor
@@ -20,13 +22,16 @@ from .safety_checker import SafeStableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-class StableDiffusionPipelineSafe(DiffusionPipeline):
+class StableDiffusionPipelineSafe(DiffusionPipeline, IPAdapterMixin):
     r"""
     Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion.
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -48,7 +53,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
     """
     model_cpu_offload_seq = "text_encoder->unet->vae"
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
     def __init__(
         self,
@@ -59,6 +64,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: SafeStableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
@@ -140,6 +146,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
             scheduler=scheduler,
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self._safety_text_concept = safety_concept
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -467,6 +474,31 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
             noise_guidance = noise_guidance - noise_guidance_safety
         return noise_guidance, safety_momentum
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+            return image_embeds, uncond_image_embeds
     @torch.no_grad()
     def __call__(
         self,
@@ -480,6 +512,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -521,6 +554,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor is generated by sampling using the supplied random `generator`.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -588,6 +623,17 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
         if not enable_safety_guidance:
             warnings.warn("Safety checker disabled!")
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if do_classifier_free_guidance:
+                if enable_safety_guidance:
+                    image_embeds = torch.cat([negative_image_embeds, image_embeds, image_embeds])
+                else:
+                    image_embeds = torch.cat([negative_image_embeds, image_embeds])
         # 3. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
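The nested branch above exists because Safe Latent Diffusion runs with a three-chunk prompt batch ([uncond, text, safety concept]) when safety guidance is enabled, so the IP-Adapter image embeds must be repeated to line up with it. A toy illustration of the layout (shapes are arbitrary, names assumed):

import torch

negative_image_embeds = torch.zeros(2, 768)  # stands in for the zero "uncond" image embeds
image_embeds = torch.randn(2, 768)           # embeds for two images per prompt

# Mirrors the [uncond, text, safety_concept] order of the prompt embeddings:
# the conditional image embeds are reused for the safety-concept chunk.
ip_embeds = torch.cat([negative_image_embeds, image_embeds, image_embeds])
assert ip_embeds.shape[0] == 3 * image_embeds.shape[0]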
@@ -613,6 +659,9 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
         # 6. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+        # 6.1 Add image embeds for IP-Adapter
+        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
         safety_momentum = None
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -627,7 +676,9 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+                noise_pred = self.unet(
+                    latent_model_input, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs
+                ).sample
                 # perform guidance
                 if do_classifier_free_guidance:
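End to end, the safe pipeline is used the same way as the others; the safety-guidance bookkeeping stays internal. A hedged sketch, assuming the `AIML-TUDA/stable-diffusion-safe` weights and the same IP-Adapter checkpoint as above (neither is part of this diff):

import torch
from diffusers import StableDiffusionPipelineSafe
from diffusers.utils import load_image

pipe = StableDiffusionPipelineSafe.from_pretrained(
    "AIML-TUDA/stable-diffusion-safe", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

reference = load_image("https://example.com/style.png")  # hypothetical image URL
image = pipe(
    prompt="portrait of an astronaut",
    ip_adapter_image=reference,
).images[0]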
@@ -87,6 +87,7 @@ class LatentConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, Pipelin
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
             "requires_safety_checker": False,
         }
         return components
@@ -97,6 +97,7 @@ class LatentConsistencyModelImg2ImgPipelineFastTests(
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
             "requires_safety_checker": False,
         }
         return components
@@ -108,6 +108,7 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
@@ -93,6 +93,7 @@ class StableDiffusionLDM3DPipelineFastTests(unittest.TestCase):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
@@ -91,6 +91,7 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
@@ -93,6 +93,7 @@ class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTes
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
+            "image_encoder": None,
         }
         return components