Unverified commit 06b01ea8 authored by YiYi Xu, committed by GitHub

[ip-adapter] refactor `prepare_ip_adapter_image_embeds` and skip load `image_encoder` (#7016)



* add

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent f4fc7503
@@ -519,7 +519,7 @@ class StableDiffusionXLPipeline(
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
@@ -543,13 +543,23 @@ class StableDiffusionXLPipeline(
                     [single_negative_image_embeds] * num_images_per_prompt, dim=0
                 )

-                if self.do_classifier_free_guidance:
+                if do_classifier_free_guidance:
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])

                 single_image_embeds = single_image_embeds.to(device)
                 image_embeds.append(single_image_embeds)
         else:
-            image_embeds = ip_adapter_image_embeds
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                image_embeds.append(single_image_embeds)
         return image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
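The interesting change is the `else` branch: user-supplied embeddings are no longer passed through untouched. A minimal standalone sketch of what that branch now does (not the pipeline method itself; shapes are illustrative):

```python
import torch

def prepare_user_embeds(ip_adapter_image_embeds, num_images_per_prompt, do_classifier_free_guidance):
    # Mirrors the refactored `else` branch: chunk off the negative half when
    # CFG is on, tile both halves per prompt, then re-concatenate.
    image_embeds = []
    for single_image_embeds in ip_adapter_image_embeds:
        if do_classifier_free_guidance:
            negative, positive = single_image_embeds.chunk(2)
            negative = negative.repeat(num_images_per_prompt, 1, 1)
            positive = positive.repeat(num_images_per_prompt, 1, 1)
            single_image_embeds = torch.cat([negative, positive])
        else:
            single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
        image_embeds.append(single_image_embeds)
    return image_embeds

# One adapter, CFG on, two images per prompt: the (negative, positive) halves
# of a (2, num_tokens, dim) tensor are each tiled to 2 rows, giving (4, 4, 32).
embeds = torch.randn(2, 4, 32)
out = prepare_user_embeds([embeds], num_images_per_prompt=2, do_classifier_free_guidance=True)
print(out[0].shape)  # torch.Size([4, 4, 32]): 2 negative rows followed by 2 positive rows
```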
@@ -656,6 +666,16 @@ class StableDiffusionXLPipeline(
                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
             )

+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim != 3:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
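To see what the new `check_inputs` guard accepts and rejects, here is a small self-contained replica (the helper name is hypothetical; the error messages match the diff):

```python
import torch

def check_ip_adapter_image_embeds(ip_adapter_image_embeds):
    # Replica of the validation added to `check_inputs` above.
    if not isinstance(ip_adapter_image_embeds, list):
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
        )
    elif ip_adapter_image_embeds[0].ndim != 3:
        raise ValueError(
            f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
        )

check_ip_adapter_image_embeds([torch.randn(2, 4, 32)])  # ok: a list of 3D tensors
# check_ip_adapter_image_embeds(torch.randn(2, 4, 32))  # raises: not a list
# check_ip_adapter_image_embeds([torch.randn(4, 32)])   # raises: 2D, not 3D
```

Note that the guard only inspects the first tensor in the list.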
@@ -890,8 +910,10 @@ class StableDiffusionXLPipeline(
                 input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
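One way to produce embeddings in the documented layout is to run the pipeline's existing `encode_image` helper once and save the result. A sketch, assuming a pipeline with its `image_encoder` loaded and a single plain (non-plus) IP-Adapter attached; the file name is hypothetical:

```python
import torch

def save_ip_adapter_embeds(pipe, image, device, path="ip_adapter_embeds.pt"):
    # `encode_image` returns the positive and negative (unconditional)
    # embeddings for one input image; plain IP-Adapters do not need the
    # penultimate hidden states, hence output_hidden_states=False.
    image_embeds, negative_image_embeds = pipe.encode_image(image, device, 1, output_hidden_states=False)
    # Negative half first, positive half second, so the pipeline's `.chunk(2)`
    # can recover both halves when classifier-free guidance is enabled.
    torch.save(torch.cat([negative_image_embeds, image_embeds]), path)
```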
@@ -1093,7 +1115,11 @@ class StableDiffusionXLPipeline(
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
             )

         # 8. Denoising loop
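Putting it together, the workflow this refactor targets (including the "skip load `image_encoder`" part of the title) looks roughly like this. A sketch, assuming embeddings saved as in the earlier snippet; the checkpoint and weight names are the usual public ones:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

# One 3D tensor per IP-Adapter; with guidance_scale > 1 the pipeline expects
# the negative half concatenated before the positive half along dim 0.
image_embeds = torch.load("ip_adapter_embeds.pt")  # hypothetical file from the earlier sketch

image = pipe(
    prompt="an astronaut riding a horse",
    ip_adapter_image_embeds=[image_embeds],
    guidance_scale=7.5,
).images[0]
```

Since no `ip_adapter_image` is passed, the `image_encoder` is never called at inference time, which is what makes skipping it possible.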
@@ -611,6 +611,16 @@ class StableDiffusionXLImg2ImgPipeline(
                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
             )

+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim != 3:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
         # get the original timestep using init_timestep
         if denoising_start is None:
@@ -742,7 +752,7 @@ class StableDiffusionXLImg2ImgPipeline(
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
@@ -766,13 +776,23 @@ class StableDiffusionXLImg2ImgPipeline(
                     [single_negative_image_embeds] * num_images_per_prompt, dim=0
                 )

-                if self.do_classifier_free_guidance:
+                if do_classifier_free_guidance:
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])

                 single_image_embeds = single_image_embeds.to(device)
                 image_embeds.append(single_image_embeds)
         else:
-            image_embeds = ip_adapter_image_embeds
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                image_embeds.append(single_image_embeds)
         return image_embeds

     def _get_add_time_ids(
@@ -1038,8 +1058,10 @@ class StableDiffusionXLImg2ImgPipeline(
                 input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1264,7 +1286,11 @@ class StableDiffusionXLImg2ImgPipeline(
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
             )

         # 9. Denoising loop
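The img2img method is `# Copied from` the base text-to-image pipeline, so the same embedding contract applies: with several IP-Adapters loaded, the list carries one 3D tensor per adapter, in loading order. A shape-only illustration (all names and dimensions are made up):

```python
import torch

# Two adapters -> two tensors; each may have its own token count and width.
ip_adapter_image_embeds = [
    torch.randn(2, 4, 2048),   # adapter 1: negative + positive halves stacked on dim 0
    torch.randn(2, 16, 1280),  # adapter 2
]
# images = img2img_pipe(prompt=prompt, image=init_image,
#                       ip_adapter_image_embeds=ip_adapter_image_embeds).images
```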
@@ -456,7 +456,7 @@ class StableDiffusionXLInpaintPipeline(
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
@@ -480,13 +480,23 @@ class StableDiffusionXLInpaintPipeline(
                     [single_negative_image_embeds] * num_images_per_prompt, dim=0
                 )

-                if self.do_classifier_free_guidance:
+                if do_classifier_free_guidance:
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])

                 single_image_embeds = single_image_embeds.to(device)
                 image_embeds.append(single_image_embeds)
         else:
-            image_embeds = ip_adapter_image_embeds
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                image_embeds.append(single_image_embeds)
         return image_embeds

     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
@@ -836,6 +846,16 @@ class StableDiffusionXLInpaintPipeline(
                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
             )

+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim != 3:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     def prepare_latents(
         self,
         batch_size,
@@ -1290,8 +1310,10 @@ class StableDiffusionXLInpaintPipeline(
                 input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -1612,7 +1634,11 @@ class StableDiffusionXLInpaintPipeline(
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
             )

         # 11. Denoising loop
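For the inpainting pipeline the same precomputed-embeddings path applies, and it is where the "skip load `image_encoder`" idea pays off most directly: with embeddings already on disk there is nothing for the encoder to do. A sketch; whether `image_encoder_folder=None` is accepted depends on your diffusers release, so treat it as an assumption:

```python
import torch
from diffusers import StableDiffusionXLInpaintPipeline

pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.bin",
    image_encoder_folder=None,  # assumption: skips fetching the CLIP image encoder on newer releases
)
```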
@@ -533,7 +533,7 @@ class StableDiffusionXLAdapterPipeline(
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
@@ -557,13 +557,23 @@ class StableDiffusionXLAdapterPipeline(
                     [single_negative_image_embeds] * num_images_per_prompt, dim=0
                 )

-                if self.do_classifier_free_guidance:
+                if do_classifier_free_guidance:
                     single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])

                 single_image_embeds = single_image_embeds.to(device)
                 image_embeds.append(single_image_embeds)
         else:
-            image_embeds = ip_adapter_image_embeds
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
+                image_embeds.append(single_image_embeds)
         return image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -671,6 +681,16 @@ class StableDiffusionXLAdapterPipeline(
                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
             )

+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim != 3:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -914,8 +934,10 @@ class StableDiffusionXLAdapterPipeline(
                 input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1057,7 +1079,11 @@ class StableDiffusionXLAdapterPipeline(
         # 3.2 Encode ip_adapter_image
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
             )

         # 4. Prepare timesteps
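In the T2I-Adapter pipeline the embeddings ride alongside the adapter's own conditioning image. An illustrative call, all inputs hypothetical:

```python
# `adapter_pipe` is a StableDiffusionXLAdapterPipeline with an IP-Adapter loaded;
# `sketch_image` conditions the T2I-Adapter, `saved_embeds` the IP-Adapter.
out = adapter_pipe(
    prompt="a watercolor landscape",
    image=sketch_image,                      # T2I-Adapter control image
    ip_adapter_image_embeds=[saved_embeds],  # precomputed, negative half first for CFG
    guidance_scale=7.5,
).images[0]
```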
@@ -319,6 +319,35 @@ class IPAdapterTesterMixin:
             "Output with multi-ip-adapter scale must be different from normal inference",
         )

+    def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4):
+        parameters = inspect.signature(self.pipeline_class.__call__).parameters
+
+        if "guidance_scale" not in parameters:
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+
+        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
+        pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+        pipe.set_ip_adapter_scale(1.0)
+
+        # forward pass with CFG not applied
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)[0].unsqueeze(0)]
+        inputs["guidance_scale"] = 1.0
+        out_no_cfg = pipe(**inputs)[0]
+
+        # forward pass with CFG applied
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+        inputs["guidance_scale"] = 7.5
+        out_cfg = pipe(**inputs)[0]
+
+        assert out_cfg.shape == out_no_cfg.shape
+

 class PipelineLatentTesterMixin:
     """
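The two `ip_adapter_image_embeds` inputs in the new test differ only in whether the negative half is present. With made-up dimensions (`_get_dummy_image_embeds` is the suite's helper; its exact shape is an assumption here), the construction looks like this:

```python
import torch

dummy = torch.randn(2, 4, 32)           # stand-in for _get_dummy_image_embeds(...)
cfg_input = [dummy]                      # (2, 4, 32): negative + positive, used with guidance_scale=7.5
no_cfg_input = [dummy[0].unsqueeze(0)]   # (1, 4, 32): positive only, used with guidance_scale=1.0
```

Both runs must succeed and produce same-shaped outputs, which is exactly what the final assert checks.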