Unverified Commit 6c60e430 authored by Álvaro Somoza's avatar Álvaro Somoza Committed by GitHub
Browse files

Consistent SDXL Controlnet callback tensor inputs (#7958)

* make _callback_tensor_inputs consistent between sdxl pipelines

* forgot this one

* fix failing test

* fix test_components_function

* fix controlnet inpaint tests
parent 1221b28e
......@@ -198,8 +198,26 @@ class StableDiffusionXLControlNetInpaintPipeline(
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = [
"tokenizer",
"tokenizer_2",
"text_encoder",
"text_encoder_2",
"image_encoder",
"feature_extractor",
]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
"mask",
"masked_image_latents",
]
def __init__(
self,
......
......@@ -236,7 +236,15 @@ class StableDiffusionXLControlNetPipeline(
"feature_extractor",
"image_encoder",
]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
]
def __init__(
self,
......@@ -1528,6 +1536,12 @@ class StableDiffusionXLControlNetPipeline(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
......@@ -228,7 +228,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
"feature_extractor",
"image_encoder",
]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
]
def __init__(
self,
......@@ -1584,6 +1592,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
......@@ -158,7 +158,15 @@ class StableDiffusionXLControlNetXSPipeline(
"text_encoder_2",
"feature_extractor",
]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
]
def __init__(
self,
......
......@@ -19,7 +19,15 @@ import unittest
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from diffusers import (
AutoencoderKL,
......@@ -34,6 +42,7 @@ from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
......@@ -55,6 +64,14 @@ class ControlNetPipelineSDXLFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = frozenset(IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"mask_image", "control_image"}))
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
{
"add_text_embeds",
"add_time_ids",
"mask",
"masked_image_latents",
}
)
def get_dummy_components(self):
torch.manual_seed(0)
......@@ -129,6 +146,30 @@ class ControlNetPipelineSDXLFastTests(
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image_encoder_config = CLIPVisionConfig(
hidden_size=32,
image_size=224,
projection_dim=32,
intermediate_size=37,
num_attention_heads=4,
num_channels=3,
num_hidden_layers=5,
patch_size=14,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
feature_extractor = CLIPImageProcessor(
crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224,
)
components = {
"unet": unet,
"controlnet": controlnet,
......@@ -138,6 +179,8 @@ class ControlNetPipelineSDXLFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": image_encoder,
"feature_extractor": feature_extractor,
}
return components
......
......@@ -34,6 +34,7 @@ from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import (
IPAdapterTesterMixin,
......@@ -55,9 +56,13 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
):
pipeline_class = StableDiffusionXLControlNetImg2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
{"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
)
def get_dummy_components(self, skip_first_text_encoder=False):
torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment