Unverified Commit 8e8954bd authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix no CFG for kandinsky pipelines (#4193)



* fix bug when no cfg

* style

* fix no cfg for shap-e and cycle

* style

* fix no cfg for sdxl

* fix copies

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent 09fab561
......@@ -384,9 +384,10 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if do_classifier_free_guidance:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
......@@ -344,9 +344,9 @@ class KandinskyPipeline(DiffusionPipeline):
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=prompt_embeds.dtype, device=device
)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=prompt_embeds.dtype, device=device
)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps_tensor = self.scheduler.timesteps
......
......@@ -414,9 +414,9 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=prompt_embeds.dtype, device=device
)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=prompt_embeds.dtype, device=device
)
# 3. pre-processing initial image
if not isinstance(image, list):
......@@ -472,9 +472,17 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
)[0]
if do_classifier_free_guidance:
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not (
hasattr(self.scheduler.config, "variance_type")
and self.scheduler.config.variance_type in ["learned", "learned_range"]
):
noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
......
......@@ -518,9 +518,9 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=prompt_embeds.dtype, device=device
)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=prompt_embeds.dtype, device=device
)
# preprocess image and mask
mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)
......
......@@ -205,7 +205,9 @@ class KandinskyV22Pipeline(DiffusionPipeline):
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=self.unet.dtype, device=device
)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps_tensor = self.scheduler.timesteps
......
......@@ -259,8 +259,11 @@ class KandinskyV22ControlnetPipeline(DiffusionPipeline):
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=self.unet.dtype, device=device
)
hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps_tensor = self.scheduler.timesteps
......
......@@ -316,8 +316,10 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
hint = hint.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=self.unet.dtype, device=device
)
hint = torch.cat([hint, hint], dim=0).to(dtype=self.unet.dtype, device=device)
if not isinstance(image, list):
image = [image]
......
......@@ -282,7 +282,9 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=self.unet.dtype, device=device
)
if not isinstance(image, list):
image = [image]
......
......@@ -383,7 +383,9 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(dtype=self.unet.dtype, device=device)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
dtype=self.unet.dtype, device=device
)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps_tensor = self.scheduler.timesteps
......
......@@ -309,7 +309,7 @@ class ShapEPipeline(DiffusionPipeline):
scaled_model_input.shape[2], dim=2
) # batch_size, num_embeddings, embedding_dim
if do_classifier_free_guidance is not None:
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
......
......@@ -265,7 +265,7 @@ class ShapEImg2ImgPipeline(DiffusionPipeline):
scaled_model_input.shape[2], dim=2
) # batch_size, num_embeddings, embedding_dim
if do_classifier_free_guidance is not None:
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
......
......@@ -773,30 +773,49 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
source_latent_model_input = torch.cat([source_latents] * 2)
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
source_latent_model_input = (
torch.cat([source_latents] * 2) if do_classifier_free_guidance else source_latents
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
# predict the noise residual
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_prompt_embeds = torch.stack(
[
source_prompt_embeds[0],
prompt_embeds[0],
source_prompt_embeds[1],
prompt_embeds[1],
],
dim=0,
)
if do_classifier_free_guidance:
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_prompt_embeds = torch.stack(
[
source_prompt_embeds[0],
prompt_embeds[0],
source_prompt_embeds[1],
prompt_embeds[1],
],
dim=0,
)
else:
concat_latent_model_input = torch.cat(
[
source_latent_model_input,
latent_model_input,
],
dim=0,
)
concat_prompt_embeds = torch.cat(
[
source_prompt_embeds,
prompt_embeds,
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input,
t,
......@@ -805,16 +824,21 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
).sample
# perform guidance
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)
if do_classifier_free_guidance:
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)
else:
(source_noise_pred, noise_pred) = concat_noise_pred.chunk(2, dim=0)
# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
......
......@@ -399,9 +399,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if do_classifier_free_guidance:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
......@@ -407,9 +407,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if do_classifier_free_guidance:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
......@@ -513,9 +513,10 @@ class StableDiffusionXLInpaintPipeline(
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if do_classifier_free_guidance:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
......@@ -22,7 +22,14 @@ import torch
from PIL import Image
from transformers import XLMRobertaTokenizerFast
from diffusers import DDIMScheduler, KandinskyImg2ImgPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers import (
DDIMScheduler,
DDPMScheduler,
KandinskyImg2ImgPipeline,
KandinskyPriorPipeline,
UNet2DConditionModel,
VQModel,
)
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
......@@ -328,3 +335,54 @@ class KandinskyImg2ImgPipelineIntegrationTests(unittest.TestCase):
assert image.shape == (768, 768, 3)
assert_mean_pixel_difference(image, expected_image)
def test_kandinsky_img2img_ddpm(self):
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/kandinsky/kandinsky_img2img_ddpm_frog.npy"
)
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/frog.png"
)
prompt = "A red cartoon frog, 4k"
pipe_prior = KandinskyPriorPipeline.from_pretrained(
"kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
)
pipe_prior.to(torch_device)
scheduler = DDPMScheduler.from_pretrained("kandinsky-community/kandinsky-2-1", subfolder="ddpm_scheduler")
pipeline = KandinskyImg2ImgPipeline.from_pretrained(
"kandinsky-community/kandinsky-2-1", scheduler=scheduler, torch_dtype=torch.float16
)
pipeline = pipeline.to(torch_device)
pipeline.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
image_emb, zero_image_emb = pipe_prior(
prompt,
generator=generator,
num_inference_steps=5,
negative_prompt="",
).to_tuple()
output = pipeline(
prompt,
image=init_image,
image_embeds=image_emb,
negative_image_embeds=zero_image_emb,
generator=generator,
num_inference_steps=100,
height=768,
width=768,
strength=0.2,
output_type="np",
)
image = output.images[0]
assert image.shape == (768, 768, 3)
assert_mean_pixel_difference(image, expected_image)
......@@ -772,6 +772,27 @@ class PipelineTesterMixin:
assert images.shape[0] == batch_size * num_images_per_prompt
def test_cfg(self):
sig = inspect.signature(self.pipeline_class.__call__)
if "guidance_scale" not in sig.parameters:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["guidance_scale"] = 1.0
out_no_cfg = pipe(**inputs)[0]
inputs["guidance_scale"] = 7.5
out_cfg = pipe(**inputs)[0]
assert out_cfg.shape == out_no_cfg.shape
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment