Unverified commit 001b1402, authored by Álvaro Somoza and committed by GitHub
Browse files

[ip-adapter] fix problem using embeds with the plus version of ip adapters (#7189)



* initial

* check_inputs fix to the rest of pipelines

* add fix for no cfg too

* use of variable

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent f55873b7
...@@ -400,15 +400,22 @@ class AnimateDiffPipeline( ...@@ -400,15 +400,22 @@ class AnimateDiffPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -509,9 +516,9 @@ class AnimateDiffPipeline( ...@@ -509,9 +516,9 @@ class AnimateDiffPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
......
...@@ -478,15 +478,22 @@ class AnimateDiffVideoToVideoPipeline( ...@@ -478,15 +478,22 @@ class AnimateDiffVideoToVideoPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -589,9 +596,9 @@ class AnimateDiffVideoToVideoPipeline( ...@@ -589,9 +596,9 @@ class AnimateDiffVideoToVideoPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
def get_timesteps(self, num_inference_steps, timesteps, strength, device): def get_timesteps(self, num_inference_steps, timesteps, strength, device):
......
...@@ -510,15 +510,22 @@ class StableDiffusionControlNetPipeline( ...@@ -510,15 +510,22 @@ class StableDiffusionControlNetPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -726,9 +733,9 @@ class StableDiffusionControlNetPipeline( ...@@ -726,9 +733,9 @@ class StableDiffusionControlNetPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
def check_image(self, image, prompt, prompt_embeds): def check_image(self, image, prompt, prompt_embeds):
......
...@@ -503,15 +503,22 @@ class StableDiffusionControlNetImg2ImgPipeline( ...@@ -503,15 +503,22 @@ class StableDiffusionControlNetImg2ImgPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -713,9 +720,9 @@ class StableDiffusionControlNetImg2ImgPipeline( ...@@ -713,9 +720,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
......
...@@ -628,15 +628,22 @@ class StableDiffusionControlNetInpaintPipeline( ...@@ -628,15 +628,22 @@ class StableDiffusionControlNetInpaintPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -871,9 +878,9 @@ class StableDiffusionControlNetInpaintPipeline( ...@@ -871,9 +878,9 @@ class StableDiffusionControlNetInpaintPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
......
...@@ -537,15 +537,22 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -537,15 +537,22 @@ class StableDiffusionXLControlNetInpaintPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -817,9 +824,9 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -817,9 +824,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
def prepare_control_image( def prepare_control_image(
......
...@@ -515,15 +515,22 @@ class StableDiffusionXLControlNetPipeline( ...@@ -515,15 +515,22 @@ class StableDiffusionXLControlNetPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -730,9 +737,9 @@ class StableDiffusionXLControlNetPipeline( ...@@ -730,9 +737,9 @@ class StableDiffusionXLControlNetPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
......
...@@ -567,15 +567,22 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -567,15 +567,22 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -794,9 +801,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -794,9 +801,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
......
...@@ -453,15 +453,22 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -453,15 +453,22 @@ class LatentConsistencyModelImg2ImgPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -647,9 +654,9 @@ class LatentConsistencyModelImg2ImgPipeline( ...@@ -647,9 +654,9 @@ class LatentConsistencyModelImg2ImgPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
@property @property
......
...@@ -437,15 +437,22 @@ class LatentConsistencyModelPipeline( ...@@ -437,15 +437,22 @@ class LatentConsistencyModelPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -579,9 +586,9 @@ class LatentConsistencyModelPipeline( ...@@ -579,9 +586,9 @@ class LatentConsistencyModelPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
@property @property
......
...@@ -582,9 +582,9 @@ class PIAPipeline( ...@@ -582,9 +582,9 @@ class PIAPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
...@@ -619,15 +619,22 @@ class PIAPipeline( ...@@ -619,15 +619,22 @@ class PIAPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
......
...@@ -520,15 +520,22 @@ class StableDiffusionPipeline( ...@@ -520,15 +520,22 @@ class StableDiffusionPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -639,9 +646,9 @@ class StableDiffusionPipeline( ...@@ -639,9 +646,9 @@ class StableDiffusionPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
......
...@@ -564,15 +564,22 @@ class StableDiffusionImg2ImgPipeline( ...@@ -564,15 +564,22 @@ class StableDiffusionImg2ImgPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -685,9 +692,9 @@ class StableDiffusionImg2ImgPipeline( ...@@ -685,9 +692,9 @@ class StableDiffusionImg2ImgPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
def get_timesteps(self, num_inference_steps, strength, device): def get_timesteps(self, num_inference_steps, strength, device):
......
...@@ -636,15 +636,22 @@ class StableDiffusionInpaintPipeline( ...@@ -636,15 +636,22 @@ class StableDiffusionInpaintPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -767,9 +774,9 @@ class StableDiffusionInpaintPipeline( ...@@ -767,9 +774,9 @@ class StableDiffusionInpaintPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
def prepare_latents( def prepare_latents(
......
...@@ -442,15 +442,22 @@ class StableDiffusionLDM3DPipeline( ...@@ -442,15 +442,22 @@ class StableDiffusionLDM3DPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -553,9 +560,9 @@ class StableDiffusionLDM3DPipeline( ...@@ -553,9 +560,9 @@ class StableDiffusionLDM3DPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
......
...@@ -414,15 +414,22 @@ class StableDiffusionPanoramaPipeline( ...@@ -414,15 +414,22 @@ class StableDiffusionPanoramaPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -550,9 +557,9 @@ class StableDiffusionPanoramaPipeline( ...@@ -550,9 +557,9 @@ class StableDiffusionPanoramaPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
......
...@@ -549,15 +549,22 @@ class StableDiffusionXLPipeline( ...@@ -549,15 +549,22 @@ class StableDiffusionXLPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -671,9 +678,9 @@ class StableDiffusionXLPipeline( ...@@ -671,9 +678,9 @@ class StableDiffusionXLPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
......
...@@ -616,9 +616,9 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -616,9 +616,9 @@ class StableDiffusionXLImg2ImgPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
...@@ -782,15 +782,22 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -782,15 +782,22 @@ class StableDiffusionXLImg2ImgPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
......
...@@ -486,15 +486,22 @@ class StableDiffusionXLInpaintPipeline( ...@@ -486,15 +486,22 @@ class StableDiffusionXLInpaintPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -851,9 +858,9 @@ class StableDiffusionXLInpaintPipeline( ...@@ -851,9 +858,9 @@ class StableDiffusionXLInpaintPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
def prepare_latents( def prepare_latents(
......
...@@ -563,15 +563,22 @@ class StableDiffusionXLAdapterPipeline( ...@@ -563,15 +563,22 @@ class StableDiffusionXLAdapterPipeline(
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
else: else:
repeat_dims = [1]
image_embeds = [] image_embeds = []
for single_image_embeds in ip_adapter_image_embeds: for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance: if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else: else:
single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
image_embeds.append(single_image_embeds) image_embeds.append(single_image_embeds)
return image_embeds return image_embeds
...@@ -686,9 +693,9 @@ class StableDiffusionXLAdapterPipeline( ...@@ -686,9 +693,9 @@ class StableDiffusionXLAdapterPipeline(
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
) )
elif ip_adapter_image_embeds[0].ndim != 3: elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError( raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
) )
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment