Unverified commit 5a287d3f, authored by Patrick von Platen, committed by GitHub

[SDXL] Make sure multi batch prompt embeds works (#5073)

* [SDXL] Make sure multi batch prompt embeds works

* [SDXL] Make sure multi batch prompt embeds works

* improve more

* improve more

* Apply suggestions from code review
Parent: 65c162a5
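
What changed: `encode_prompt` used to special-case `str` vs. `list` prompts (`batch_size = 1` for a string, `len(prompt)` for a list), and a string `negative_prompt` was only handled by a dedicated `uncond_tokens` branch. With a multi-element prompt batch, those string paths could mis-size the unconditional batch. The commit normalizes every string input to a list up front, so a single code path serves both. A minimal standalone sketch of the new batch-size logic (illustrative only, not the pipeline source; `prompt_embeds_batch` stands in for `prompt_embeds.shape[0]`):

```python
from typing import List, Optional, Union

def resolve_batch_size(
    prompt: Optional[Union[str, List[str]]],
    prompt_embeds_batch: Optional[int] = None,
) -> int:
    # Post-commit behavior: a plain string becomes a one-element list,
    # so str and list prompts flow through the same branch.
    prompt = [prompt] if isinstance(prompt, str) else prompt
    if prompt is not None:
        return len(prompt)
    # No text prompt: the batch size comes from user-supplied embeddings.
    return prompt_embeds_batch

assert resolve_batch_size("a cat") == 1
assert resolve_batch_size(["a cat", "a dog", "a fox"]) == 3
assert resolve_batch_size(None, prompt_embeds_batch=3) == 3
```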
@@ -314,9 +314,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -329,6 +329,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -378,14 +380,18 @@ class StableDiffusionXLControlNetInpaintPipeline(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...
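
The same three hunks are applied, modulo line numbers, to every SDXL pipeline below. The third hunk is the behavioral fix: it drops the old `elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt, negative_prompt_2]` branch and instead broadcasts a string negative prompt across the batch before the length check, so one negative prompt can pair with several positive prompts. A hedged sketch of that broadcast (standalone, not pipeline code):

```python
batch_size = 3  # e.g. three positive prompts in the batch
negative_prompt = "blurry, low quality"  # a single string from the caller

# normalize str to list: replicate the string once per batch element
negative_prompt = (
    batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
)

assert negative_prompt == 3 * ["blurry, low quality"]
# The subsequent `batch_size != len(negative_prompt)` guard now passes.
assert len(negative_prompt) == batch_size
```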
@@ -287,9 +287,9 @@ class StableDiffusionXLControlNetPipeline(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -302,6 +302,8 @@ class StableDiffusionXLControlNetPipeline(
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -351,14 +353,18 @@ class StableDiffusionXLControlNetPipeline(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...
@@ -325,9 +325,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -340,6 +340,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -389,14 +391,18 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...
@@ -263,9 +263,9 @@ class StableDiffusionXLPipeline(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -278,6 +278,8 @@ class StableDiffusionXLPipeline(
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -327,14 +329,18 @@ class StableDiffusionXLPipeline(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...
@@ -270,9 +270,9 @@ class StableDiffusionXLImg2ImgPipeline(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -285,6 +285,8 @@ class StableDiffusionXLImg2ImgPipeline(
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -334,14 +336,18 @@ class StableDiffusionXLImg2ImgPipeline(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...
@@ -419,9 +419,9 @@ class StableDiffusionXLInpaintPipeline(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -434,6 +434,8 @@ class StableDiffusionXLInpaintPipeline(
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -483,14 +485,18 @@ class StableDiffusionXLInpaintPipeline(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...
@@ -287,9 +287,9 @@ class StableDiffusionXLAdapterPipeline(
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
@@ -302,6 +302,8 @@ class StableDiffusionXLAdapterPipeline(
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
@@ -351,14 +353,18 @@ class StableDiffusionXLAdapterPipeline(
             negative_prompt = negative_prompt or ""
             negative_prompt_2 = negative_prompt_2 or negative_prompt
 
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                     f" {type(prompt)}."
                 )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
...
@@ -261,6 +261,42 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
         assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
         assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
 
+    def test_stable_diffusion_xl_img2img_prompt_embeds_only(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["prompt"] = 3 * [inputs["prompt"]]
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        prompt = 3 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            _,
+            pooled_prompt_embeds,
+            _,
+        ) = sd_pipe.encode_prompt(prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
         components = self.get_dummy_components()
         pipe_1 = StableDiffusionXLPipeline(**components).to(torch_device)
...
@@ -559,6 +559,42 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
 
+    def test_stable_diffusion_xl_img2img_prompt_embeds_only(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        inputs["prompt"] = 3 * [inputs["prompt"]]
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        prompt = 3 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            _,
+            pooled_prompt_embeds,
+            _,
+        ) = sd_pipe.encode_prompt(prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
     def test_attention_slicing_forward_pass(self):
         super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
...
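
The new tests encode a three-prompt batch once, feed the resulting embeddings back into the pipeline, and assert pixel-level agreement with the plain-prompt path. As user-facing code the pattern looks roughly like the sketch below (the checkpoint id, dtype, and device are placeholders; `encode_prompt` returns positive/negative and pooled/unpooled embeddings in the same order as in the tests above):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Encode a batch of three prompts once...
prompt = 3 * ["an astronaut riding a horse"]
prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(prompt)

# ...then drive the pipeline from the embeddings alone. Before this fix,
# the str/list special-casing in encode_prompt could mis-size the
# unconditional batch for multi-batch embeddings like these.
images = pipe(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
).images
assert len(images) == 3
```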