Unverified Commit 6c60e430 authored by Álvaro Somoza's avatar Álvaro Somoza Committed by GitHub
Browse files

Consistent SDXL Controlnet callback tensor inputs (#7958)

* make _callback_tensor_inputs consistent between sdxl pipelines

* forgot this one

* fix failing test

* fix test_components_function

* fix controlnet inpaint tests
parent 1221b28e
...@@ -198,8 +198,26 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -198,8 +198,26 @@ class StableDiffusionXLControlNetInpaintPipeline(
""" """
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] _optional_components = [
"tokenizer",
"tokenizer_2",
"text_encoder",
"text_encoder_2",
"image_encoder",
"feature_extractor",
]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
"mask",
"masked_image_latents",
]
def __init__( def __init__(
self, self,
......
...@@ -236,7 +236,15 @@ class StableDiffusionXLControlNetPipeline( ...@@ -236,7 +236,15 @@ class StableDiffusionXLControlNetPipeline(
"feature_extractor", "feature_extractor",
"image_encoder", "image_encoder",
] ]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] _callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
]
def __init__( def __init__(
self, self,
...@@ -1528,6 +1536,12 @@ class StableDiffusionXLControlNetPipeline( ...@@ -1528,6 +1536,12 @@ class StableDiffusionXLControlNetPipeline(
latents = callback_outputs.pop("latents", latents) latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
...@@ -228,7 +228,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -228,7 +228,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
"feature_extractor", "feature_extractor",
"image_encoder", "image_encoder",
] ]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] _callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"add_neg_time_ids",
]
def __init__( def __init__(
self, self,
...@@ -1584,6 +1592,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1584,6 +1592,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
latents = callback_outputs.pop("latents", latents) latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
...@@ -158,7 +158,15 @@ class StableDiffusionXLControlNetXSPipeline( ...@@ -158,7 +158,15 @@ class StableDiffusionXLControlNetXSPipeline(
"text_encoder_2", "text_encoder_2",
"feature_extractor", "feature_extractor",
] ]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] _callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
]
def __init__( def __init__(
self, self,
......
...@@ -19,7 +19,15 @@ import unittest ...@@ -19,7 +19,15 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
...@@ -34,6 +42,7 @@ from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor ...@@ -34,6 +42,7 @@ from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor
from ..pipeline_params import ( from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS, IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS,
) )
...@@ -55,6 +64,14 @@ class ControlNetPipelineSDXLFastTests( ...@@ -55,6 +64,14 @@ class ControlNetPipelineSDXLFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = frozenset(IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"mask_image", "control_image"})) image_params = frozenset(IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"mask_image", "control_image"}))
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
{
"add_text_embeds",
"add_time_ids",
"mask",
"masked_image_latents",
}
)
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
...@@ -129,6 +146,30 @@ class ControlNetPipelineSDXLFastTests( ...@@ -129,6 +146,30 @@ class ControlNetPipelineSDXLFastTests(
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
image_encoder_config = CLIPVisionConfig(
hidden_size=32,
image_size=224,
projection_dim=32,
intermediate_size=37,
num_attention_heads=4,
num_channels=3,
num_hidden_layers=5,
patch_size=14,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
feature_extractor = CLIPImageProcessor(
crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224,
)
components = { components = {
"unet": unet, "unet": unet,
"controlnet": controlnet, "controlnet": controlnet,
...@@ -138,6 +179,8 @@ class ControlNetPipelineSDXLFastTests( ...@@ -138,6 +179,8 @@ class ControlNetPipelineSDXLFastTests(
"tokenizer": tokenizer, "tokenizer": tokenizer,
"text_encoder_2": text_encoder_2, "text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2, "tokenizer_2": tokenizer_2,
"image_encoder": image_encoder,
"feature_extractor": feature_extractor,
} }
return components return components
......
...@@ -34,6 +34,7 @@ from ..pipeline_params import ( ...@@ -34,6 +34,7 @@ from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS, IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
) )
from ..test_pipelines_common import ( from ..test_pipelines_common import (
IPAdapterTesterMixin, IPAdapterTesterMixin,
...@@ -55,9 +56,13 @@ class ControlNetPipelineSDXLImg2ImgFastTests( ...@@ -55,9 +56,13 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
): ):
pipeline_class = StableDiffusionXLControlNetImg2ImgPipeline pipeline_class = StableDiffusionXLControlNetImg2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
{"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
)
def get_dummy_components(self, skip_first_text_encoder=False): def get_dummy_components(self, skip_first_text_encoder=False):
torch.manual_seed(0) torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment