Unverified Commit 8bf046b7 authored by Dhruv Nair, committed by GitHub

Add single file and IP Adapter support to PIA Pipeline (#6851)

update
parent bb99623d
@@ -24,7 +24,7 @@ import torch.fft as fft
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_motion_model import MotionAdapter
@@ -209,7 +209,9 @@ class PIAPipelineOutput(BaseOutput):
     frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]


-class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
+class PIAPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-video generation.
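Adding `FromSingleFileMixin` to the class bases is what exposes `PIAPipeline.from_single_file`, which builds the pipeline from a monolithic `.safetensors`/`.ckpt` checkpoint instead of a diffusers-layout repo. A minimal sketch of the resulting usage, assuming the usual PIA motion adapter repo; the checkpoint URL is a placeholder and neither it nor the repo id appears in this diff:

```python
import torch
from diffusers import MotionAdapter, PIAPipeline

# The PIA motion adapter is loaded separately and passed in; the repo id and
# the checkpoint URL below are illustrative assumptions, not part of this commit.
adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
pipe = PIAPipeline.from_single_file(
    "https://huggingface.co/<org>/<repo>/blob/main/model.safetensors",  # placeholder
    motion_adapter=adapter,
    torch_dtype=torch.float16,
)
```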
@@ -685,6 +687,35 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
                 f" {negative_prompt_embeds.shape}."
             )

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+        if not isinstance(ip_adapter_image, list):
+            ip_adapter_image = [ip_adapter_image]
+
+        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            raise ValueError(
+                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+            )
+
+        image_embeds = []
+        for single_ip_adapter_image, image_proj_layer in zip(
+            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+        ):
+            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(
+                single_ip_adapter_image, device, 1, output_hidden_state
+            )
+            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+            if self.do_classifier_free_guidance:
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                single_image_embeds = single_image_embeds.to(device)
+
+            image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
     # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
     def prepare_latents(
         self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
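The helper added above mirrors the Stable Diffusion implementation: it validates that one image is supplied per loaded IP Adapter, encodes each image (using hidden states unless the projection layer is a plain `ImageProjection`), tiles the result `num_images_per_prompt` times, and prepends the negative embeddings under classifier-free guidance. A hedged end-to-end sketch of what this enables; the repo ids, weight names, and image URL below are common pairings assumed for illustration, not taken from this diff:

```python
import torch
from diffusers import MotionAdapter, PIAPipeline
from diffusers.utils import load_image

# Assumed model ids: the usual PIA adapter, an SD v1.5 base, and the
# h94/IP-Adapter weights; this diff does not pin any of them down.
adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
pipe = PIAPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

image = load_image("https://example.com/input.png")  # placeholder input frame
output = pipe(
    prompt="a cat dancing",
    image=image,             # the still frame PIA animates
    ip_adapter_image=image,  # encoded via prepare_ip_adapter_image_embeds
)
frames = output.frames[0]
```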
@@ -1107,12 +1138,9 @@ class PIAPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
-            )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_videos_per_prompt
+            )

         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
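At the call site, the inline `encode_image` logic is replaced by the shared helper, and the embeddings are now tiled for the full batch (`batch_size * num_videos_per_prompt`) rather than per prompt only. Because the helper zips the supplied images against `encoder_hid_proj.image_projection_layers`, a list input becomes possible once several IP Adapters are loaded; whether PIA is used that way in practice is an assumption, but the shape contract is visible in the code. Continuing from the sketch above:

```python
# Hypothetical multi-adapter call, reusing `pipe` and `image` from the previous
# sketch: ip_adapter_image must contain exactly one image per loaded IP Adapter,
# or prepare_ip_adapter_image_embeds raises a ValueError.
output = pipe(
    prompt="a cat dancing",
    image=image,
    ip_adapter_image=[style_image, face_image],  # one entry per adapter
)
```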