"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "95295061d0dbe09d0258516cdd780d0d64edb951"
Unverified Commit 69f49195 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Fix passing pooled prompt embeds to Cascade Decoder and Combined Pipeline (#7287)

* update

* update

* update

* update
parent ed224f94
...@@ -289,7 +289,9 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -289,7 +289,9 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
guidance_scale: float = 0.0, guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
...@@ -321,10 +323,17 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -321,10 +323,17 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
prompt_embeds (`torch.FloatTensor`, *optional*): prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument. provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*): negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument. argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
argument.
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
...@@ -378,7 +387,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -378,7 +387,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
# 2. Encode caption # 2. Encode caption
if prompt_embeds is None and negative_prompt_embeds is None: if prompt_embeds is None and negative_prompt_embeds is None:
prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt( _, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt(
prompt=prompt, prompt=prompt,
device=device, device=device,
batch_size=batch_size, batch_size=batch_size,
...@@ -386,10 +395,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): ...@@ -386,10 +395,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
do_classifier_free_guidance=self.do_classifier_free_guidance, do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
) )
# The pooled embeds from the prior are pooled again before being passed to the decoder
prompt_embeds_pooled = ( prompt_embeds_pooled = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
if self.do_classifier_free_guidance
else prompt_embeds_pooled
) )
effnet = ( effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
......
...@@ -155,14 +155,14 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): ...@@ -155,14 +155,14 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
height: int = 512, height: int = 512,
width: int = 512, width: int = 512,
prior_num_inference_steps: int = 60, prior_num_inference_steps: int = 60,
prior_timesteps: Optional[List[float]] = None,
prior_guidance_scale: float = 4.0, prior_guidance_scale: float = 4.0,
num_inference_steps: int = 12, num_inference_steps: int = 12,
decoder_timesteps: Optional[List[float]] = None,
decoder_guidance_scale: float = 0.0, decoder_guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1, num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None, latents: Optional[torch.FloatTensor] = None,
...@@ -187,10 +187,17 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): ...@@ -187,10 +187,17 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
prompt_embeds (`torch.FloatTensor`, *optional*): prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings will be generated from `prompt` input argument. weighting. If not provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*): negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
input argument. input argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt`
input argument.
num_images_per_prompt (`int`, *optional*, defaults to 1): num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt. The number of images to generate per prompt.
height (`int`, *optional*, defaults to 512): height (`int`, *optional*, defaults to 512):
...@@ -253,7 +260,6 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): ...@@ -253,7 +260,6 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
""" """
prior_outputs = self.prior_pipe( prior_outputs = self.prior_pipe(
prompt=prompt if prompt_embeds is None else None, prompt=prompt if prompt_embeds is None else None,
images=images, images=images,
...@@ -263,7 +269,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): ...@@ -263,7 +269,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
guidance_scale=prior_guidance_scale, guidance_scale=prior_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None, negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
generator=generator, generator=generator,
latents=latents, latents=latents,
...@@ -274,7 +282,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): ...@@ -274,7 +282,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
) )
image_embeddings = prior_outputs.image_embeddings image_embeddings = prior_outputs.image_embeddings
prompt_embeds = prior_outputs.get("prompt_embeds", None) prompt_embeds = prior_outputs.get("prompt_embeds", None)
prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None)
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None) negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None)
outputs = self.decoder_pipe( outputs = self.decoder_pipe(
image_embeddings=image_embeddings, image_embeddings=image_embeddings,
...@@ -283,7 +293,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): ...@@ -283,7 +293,9 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
guidance_scale=decoder_guidance_scale, guidance_scale=decoder_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None, negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
generator=generator, generator=generator,
output_type=output_type, output_type=output_type,
return_dict=return_dict, return_dict=return_dict,
......
...@@ -64,7 +64,9 @@ class StableCascadePriorPipelineOutput(BaseOutput): ...@@ -64,7 +64,9 @@ class StableCascadePriorPipelineOutput(BaseOutput):
image_embeddings: Union[torch.FloatTensor, np.ndarray] image_embeddings: Union[torch.FloatTensor, np.ndarray]
prompt_embeds: Union[torch.FloatTensor, np.ndarray] prompt_embeds: Union[torch.FloatTensor, np.ndarray]
prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray] negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray]
negative_prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
class StableCascadePriorPipeline(DiffusionPipeline): class StableCascadePriorPipeline(DiffusionPipeline):
...@@ -305,6 +307,16 @@ class StableCascadePriorPipeline(DiffusionPipeline): ...@@ -305,6 +307,16 @@ class StableCascadePriorPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
if prompt_embeds is not None and prompt_embeds_pooled is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
)
if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
)
if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None: if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape: if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
raise ValueError( raise ValueError(
...@@ -339,7 +351,7 @@ class StableCascadePriorPipeline(DiffusionPipeline): ...@@ -339,7 +351,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
def num_timesteps(self): def num_timesteps(self):
return self._num_timesteps return self._num_timesteps
def get_t_condioning(self, t, alphas_cumprod): def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
s = torch.tensor([0.003]) s = torch.tensor([0.003])
clamp_range = [0, 1] clamp_range = [0, 1]
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2 min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
...@@ -558,7 +570,7 @@ class StableCascadePriorPipeline(DiffusionPipeline): ...@@ -558,7 +570,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
for i, t in enumerate(self.progress_bar(timesteps)): for i, t in enumerate(self.progress_bar(timesteps)):
if not isinstance(self.scheduler, DDPMWuerstchenScheduler): if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
if len(alphas_cumprod) > 0: if len(alphas_cumprod) > 0:
timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod) timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device) timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
else: else:
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype) timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
...@@ -609,6 +621,18 @@ class StableCascadePriorPipeline(DiffusionPipeline): ...@@ -609,6 +621,18 @@ class StableCascadePriorPipeline(DiffusionPipeline):
) # float() as bfloat16-> numpy doesnt work ) # float() as bfloat16-> numpy doesnt work
if not return_dict: if not return_dict:
return (latents, prompt_embeds, negative_prompt_embeds) return (
latents,
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
)
return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds) return StableCascadePriorPipelineOutput(
image_embeddings=latents,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
)
...@@ -241,6 +241,39 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC ...@@ -241,6 +241,39 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC
def test_callback_inputs(self): def test_callback_inputs(self):
super().test_callback_inputs() super().test_callback_inputs()
# def test_callback_cfg(self): def test_stable_cascade_combined_prompt_embeds(self):
# pass device = "cpu"
# pass components = self.get_dummy_components()
pipe = StableCascadeCombinedPipeline(**components)
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of a shiba inu, wearing a hat"
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = pipe.prior_pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
generator = torch.Generator(device=device)
output_prompt = pipe(
prompt=prompt,
num_inference_steps=1,
prior_num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
output_prompt_embeds = pipe(
prompt=None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_inference_steps=1,
prior_num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5
...@@ -207,6 +207,45 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa ...@@ -207,6 +207,45 @@ class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCa
def test_float16_inference(self): def test_float16_inference(self):
super().test_float16_inference() super().test_float16_inference()
def test_stable_cascade_decoder_prompt_embeds(self):
device = "cpu"
components = self.get_dummy_components()
pipe = StableCascadeDecoderPipeline(**components)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image_embeddings = inputs["image_embeddings"]
prompt = "A photograph of a shiba inu, wearing a hat"
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
generator = torch.Generator(device=device)
decoder_output_prompt = pipe(
image_embeddings=image_embeddings,
prompt=prompt,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
decoder_output_prompt_embeds = pipe(
image_embeddings=image_embeddings,
prompt=None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -273,6 +273,41 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase ...@@ -273,6 +273,41 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase
self.assertTrue(image_embed.shape == lora_image_embed.shape) self.assertTrue(image_embed.shape == lora_image_embed.shape)
def test_stable_cascade_decoder_prompt_embeds(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of a shiba inu, wearing a hat"
(
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
generator = torch.Generator(device=device)
output_prompt = pipe(
prompt=prompt,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
output_prompt_embeds = pipe(
prompt=None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_inference_steps=1,
output_type="np",
generator=generator.manual_seed(0),
)
assert np.abs(output_prompt.image_embeddings - output_prompt_embeds.image_embeddings).max() < 1e-5
@slow @slow
@require_torch_gpu @require_torch_gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment