Unverified Commit da843b3d authored by jquintanilla4, committed by GitHub

.load_ip_adapter in StableDiffusionXLAdapterPipeline (#6246)



* Added testing notebook and .load_ip_adapter to XLAdapterPipeline

* Added annotations

* deleted testing notebook

* Update src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* code clean up

* Add feature_extractor and image_encoder to components

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 17cece07
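
For orientation, here is a minimal usage sketch of the feature this commit enables. It is not taken from the diff below; the repo IDs, weight name, and image paths ("stabilityai/stable-diffusion-xl-base-1.0", "TencentARC/t2i-adapter-canny-sdxl-1.0", "h94/IP-Adapter", "ip-adapter_sdxl.bin", the local PNG files) are illustrative assumptions.

# Minimal sketch: combine a T2I-Adapter with an IP-Adapter in the SDXL adapter pipeline.
# Model repos, weight names, and image paths are assumptions for illustration only.
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

# T2I-Adapter provides structural conditioning (e.g. canny edges)
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16)

pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=adapter,
    torch_dtype=torch.float16,
).to("cuda")

# New in this PR: the pipeline inherits IPAdapterMixin, so load_ip_adapter is available
# and __call__ accepts an ip_adapter_image argument.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

structure_image = load_image("canny_edges.png")    # consumed by the T2I-Adapter
style_image = load_image("style_reference.png")    # consumed by the IP-Adapter image encoder

image = pipe(
    prompt="a futuristic city at dusk",
    image=structure_image,
    ip_adapter_image=style_image,
    num_inference_steps=30,
).images[0]
image.save("out.png")
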
@@ -18,11 +18,22 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import PIL.Image
 import torch
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -169,7 +180,11 @@ def retrieve_timesteps(
 class StableDiffusionXLAdapterPipeline(
-    DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+    DiffusionPipeline,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -183,6 +198,7 @@ class StableDiffusionXLAdapterPipeline(
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
@@ -211,8 +227,15 @@ class StableDiffusionXLAdapterPipeline(
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
 
     def __init__(
         self,
@@ -225,6 +248,8 @@ class StableDiffusionXLAdapterPipeline(
         adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
@@ -237,6 +262,8 @@ class StableDiffusionXLAdapterPipeline(
             unet=unet,
             adapter=adapter,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -511,6 +538,31 @@ class StableDiffusionXLAdapterPipeline(
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -768,7 +820,7 @@ class StableDiffusionXLAdapterPipeline(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -785,6 +837,7 @@ class StableDiffusionXLAdapterPipeline(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -876,6 +929,7 @@ class StableDiffusionXLAdapterPipeline(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -991,7 +1045,7 @@ class StableDiffusionXLAdapterPipeline(
         device = self._execution_device
 
-        # 3. Encode input prompt
+        # 3.1 Encode input prompt
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -1012,6 +1066,15 @@ class StableDiffusionXLAdapterPipeline(
             clip_skip=clip_skip,
         )
 
+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -1028,10 +1091,10 @@ class StableDiffusionXLAdapterPipeline(
             latents,
         )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 6.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 6.5 Optionally get Guidance Scale Embedding
+        # 6.2 Optionally get Guidance Scale Embedding
         timestep_cond = None
         if self.unet.config.time_cond_proj_dim is not None:
             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -1090,8 +1153,7 @@ class StableDiffusionXLAdapterPipeline(
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
-        # Apply denoising_end
+        # 7.1 Apply denoising_end
         if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
             discrete_timestep_cutoff = int(
                 round(
@@ -1109,9 +1171,12 @@ class StableDiffusionXLAdapterPipeline(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
+                # predict the noise residual
                 if i < int(num_inference_steps * adapter_conditioning_factor):
                     down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                 else:
@@ -1123,9 +1188,9 @@ class StableDiffusionXLAdapterPipeline(
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
-                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )[0]
 
                 # perform guidance
@@ -159,7 +159,8 @@ class StableDiffusionXLAdapterPipelineFastTests(
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
             # "safety_checker": None,
-            # "feature_extractor": None,
+            "feature_extractor": None,
+            "image_encoder": None,
         }
         return components
@@ -265,7 +266,8 @@ class StableDiffusionXLAdapterPipelineFastTests(
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
             # "safety_checker": None,
-            # "feature_extractor": None,
+            "feature_extractor": None,
+            "image_encoder": None,
         }
         return components