"vscode:/vscode.git/clone" did not exist on "cc2205832443176fb4c1a9b02f21929b67846fbe"
Unverified Commit 001b1402 authored by Álvaro Somoza's avatar Álvaro Somoza Committed by GitHub
Browse files

[ip-adapter] fix problem using embeds with the plus version of ip adapters (#7189)



* initial

* check_inputs fix to the rest of pipelines

* add fix for no cfg too

* use of variable

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent f55873b7
......@@ -400,15 +400,22 @@ class AnimateDiffPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -509,9 +516,9 @@ class AnimateDiffPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
......
......@@ -478,15 +478,22 @@ class AnimateDiffVideoToVideoPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -589,9 +596,9 @@ class AnimateDiffVideoToVideoPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
......
......@@ -510,15 +510,22 @@ class StableDiffusionControlNetPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -726,9 +733,9 @@ class StableDiffusionControlNetPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def check_image(self, image, prompt, prompt_embeds):
......
......@@ -503,15 +503,22 @@ class StableDiffusionControlNetImg2ImgPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -713,9 +720,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
......
......@@ -628,15 +628,22 @@ class StableDiffusionControlNetInpaintPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -871,9 +878,9 @@ class StableDiffusionControlNetInpaintPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
......
......@@ -537,15 +537,22 @@ class StableDiffusionXLControlNetInpaintPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -817,9 +824,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def prepare_control_image(
......
......@@ -515,15 +515,22 @@ class StableDiffusionXLControlNetPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -730,9 +737,9 @@ class StableDiffusionXLControlNetPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
......
......@@ -567,15 +567,22 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -794,9 +801,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
......
......@@ -453,15 +453,22 @@ class LatentConsistencyModelImg2ImgPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -647,9 +654,9 @@ class LatentConsistencyModelImg2ImgPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
@property
......
......@@ -437,15 +437,22 @@ class LatentConsistencyModelPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -579,9 +586,9 @@ class LatentConsistencyModelPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
@property
......
......@@ -582,9 +582,9 @@ class PIAPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
......@@ -619,15 +619,22 @@ class PIAPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......
......@@ -520,15 +520,22 @@ class StableDiffusionPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -639,9 +646,9 @@ class StableDiffusionPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
......
......@@ -564,15 +564,22 @@ class StableDiffusionImg2ImgPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -685,9 +692,9 @@ class StableDiffusionImg2ImgPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def get_timesteps(self, num_inference_steps, strength, device):
......
......@@ -636,15 +636,22 @@ class StableDiffusionInpaintPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -767,9 +774,9 @@ class StableDiffusionInpaintPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def prepare_latents(
......
......@@ -442,15 +442,22 @@ class StableDiffusionLDM3DPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -553,9 +560,9 @@ class StableDiffusionLDM3DPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
......
......@@ -414,15 +414,22 @@ class StableDiffusionPanoramaPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -550,9 +557,9 @@ class StableDiffusionPanoramaPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
......
......@@ -549,15 +549,22 @@ class StableDiffusionXLPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -671,9 +678,9 @@ class StableDiffusionXLPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
......
......@@ -616,9 +616,9 @@ class StableDiffusionXLImg2ImgPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
......@@ -782,15 +782,22 @@ class StableDiffusionXLImg2ImgPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......
......@@ -486,15 +486,22 @@ class StableDiffusionXLInpaintPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -851,9 +858,9 @@ class StableDiffusionXLInpaintPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def prepare_latents(
......
......@@ -563,15 +563,22 @@ class StableDiffusionXLAdapterPipeline(
image_embeds.append(single_image_embeds)
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds)
return image_embeds
......@@ -686,9 +693,9 @@ class StableDiffusionXLAdapterPipeline(
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim != 3:
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D"
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment