"git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "0091f1e2ad2d96acde381d769df58f784f3dbad2"
Unverified Commit 69f49195 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Fix passing pooled prompt embeds to Cascade Decoder and Combined Pipeline (#7287)

* update

* update

* update

* update
parent ed224f94
......@@ -289,7 +289,9 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
......@@ -321,10 +323,17 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
......@@ -378,7 +387,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
# 2. Encode caption
if prompt_embeds is None and negative_prompt_embeds is None:
prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt(
_, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt(
prompt=prompt,
device=device,
batch_size=batch_size,
......@@ -386,10 +395,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
)
# The pooled embeds from the prior are pooled again before being passed to the decoder
prompt_embeds_pooled = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds
torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
if self.do_classifier_free_guidance
else prompt_embeds_pooled
)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
......
......@@ -155,14 +155,14 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
height: int = 512,
width: int = 512,
prior_num_inference_steps: int = 60,
prior_timesteps: Optional[List[float]] = None,
prior_guidance_scale: float = 4.0,
num_inference_steps: int = 12,
decoder_timesteps: Optional[List[float]] = None,
decoder_guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
......@@ -187,10 +187,17 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
input argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
input argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to 512):
......@@ -253,7 +260,6 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
prior_outputs = self.prior_pipe(
prompt=prompt if prompt_embeds is None else None,
images=images,
......@@ -263,7 +269,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
guidance_scale=prior_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
......@@ -274,7 +282,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
)
image_embeddings = prior_outputs.image_embeddings
prompt_embeds = prior_outputs.get("prompt_embeds", None)
prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None)
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None)
outputs = self.decoder_pipe(
image_embeddings=image_embeddings,
......@@ -283,7 +293,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
guidance_scale=decoder_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
generator=generator,
output_type=output_type,
return_dict=return_dict,
......
......@@ -64,7 +64,9 @@ class StableCascadePriorPipelineOutput(BaseOutput):
image_embeddings: Union[torch.FloatTensor, np.ndarray]
prompt_embeds: Union[torch.FloatTensor, np.ndarray]
prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray]
negative_prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
class StableCascadePriorPipeline(DiffusionPipeline):
......@@ -305,6 +307,16 @@ class StableCascadePriorPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and prompt_embeds_pooled is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
)
if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
)
if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
raise ValueError(
......@@ -339,7 +351,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
def num_timesteps(self):
return self._num_timesteps
def get_t_condioning(self, t, alphas_cumprod):
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
s = torch.tensor([0.003])
clamp_range = [0, 1]
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
......@@ -558,7 +570,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
for i, t in enumerate(self.progress_bar(timesteps)):
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
if len(alphas_cumprod) > 0:
timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
else:
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
......@@ -609,6 +621,18 @@ class StableCascadePriorPipeline(DiffusionPipeline):
) # float() as bfloat16-> numpy doesnt work
if not return_dict:
return (latents, prompt_embeds, negative_prompt_embeds)
return (
latents,
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
)
return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds)
return StableCascadePriorPipelineOutput(
image_embeddings=latents,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
)
......@@ -241,6 +241,39 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC
def test_callback_inputs(self):
super().test_callback_inputs()
# def test_callback_cfg(self):
# pass
# pass
def test_stable_cascade_combined_prompt_embeds(self):
device = "cpu"
components = self.get_dummy_components()
pipe = StableCascadeCombinedPipeline(**components)
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of a shiba inu, wearing a hat"
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = pipe.prior_pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
generator = torch.Generator(device=device)
output_prompt = pipe(
prompt=prompt,
num_inference_steps=1,
prior_num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
output_prompt_embeds = pipe(
prompt=None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_inference_steps=1,
prior_num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5
......@@ -207,6 +207,45 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa
def test_float16_inference(self):
super().test_float16_inference()
def test_stable_cascade_decoder_prompt_embeds(self):
device = "cpu"
components = self.get_dummy_components()
pipe = StableCascadeDecoderPipeline(**components)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image_embeddings = inputs["image_embeddings"]
prompt = "A photograph of a shiba inu, wearing a hat"
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
generator = torch.Generator(device=device)
decoder_output_prompt = pipe(
image_embeddings=image_embeddings,
prompt=prompt,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
decoder_output_prompt_embeds = pipe(
image_embeddings=image_embeddings,
prompt=None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
@slow
@require_torch_gpu
......
......@@ -273,6 +273,41 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase
self.assertTrue(image_embed.shape == lora_image_embed.shape)
def test_stable_cascade_decoder_prompt_embeds(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of a shiba inu, wearing a hat"
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
generator = torch.Generator(device=device)
output_prompt = pipe(
prompt=prompt,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
output_prompt_embeds = pipe(
prompt=None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
assert np.abs(output_prompt.image_embeddings - output_prompt_embeds.image_embeddings).max() < 1e-5
@slow
@require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment