Unverified Commit 3ba36f97 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[SD-XL] Fix sdxl controlnet inference (#4238)

* Fix controlnet xl inference

* correct some sd xl control inference
parent b288684d
...@@ -180,11 +180,19 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -180,11 +180,19 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
device = torch.device(f"cuda:{gpu_id}") device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
model_sequence = (
[self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
)
model_sequence.extend([self.unet, self.vae])
hook = None hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: for cpu_offloaded_model in model_sequence:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# controlnet hook has to be manually offloaded as it alternates with unet
cpu_offload_with_hook(self.controlnet, device) cpu_offload_with_hook(self.controlnet, device)
# We'll offload the last model manually. # We'll offload the last model manually.
...@@ -639,7 +647,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -639,7 +647,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
num_inference_steps: int = 50, num_inference_steps: int = 50,
guidance_scale: float = 7.5, guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
...@@ -657,9 +665,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -657,9 +665,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
guess_mode: bool = False, guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0, control_guidance_end: Union[float, List[float]] = 1.0,
original_size: Tuple[int, int] = (1024, 1024), original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0), crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = (1024, 1024), target_size: Tuple[int, int] = None,
): ):
r""" r"""
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -875,6 +883,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -875,6 +883,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
] ]
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
original_size = original_size or image.shape[-2:]
target_size = target_size or (height, width)
# 7.2 Prepare added time ids & embeddings # 7.2 Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids( add_time_ids = self._get_add_time_ids(
......
...@@ -28,9 +28,7 @@ from diffusers import ( ...@@ -28,9 +28,7 @@ from diffusers import (
) )
from diffusers.utils import randn_tensor, torch_device from diffusers.utils import randn_tensor, torch_device
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
enable_full_determinism,
)
from ..pipeline_params import ( from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS, IMAGE_TO_IMAGE_IMAGE_PARAMS,
...@@ -125,10 +123,10 @@ class ControlNetPipelineSDXLFastTests( ...@@ -125,10 +123,10 @@ class ControlNetPipelineSDXLFastTests(
projection_dim=32, projection_dim=32,
) )
text_encoder = CLIPTextModel(text_encoder_config) text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = { components = {
"unet": unet, "unet": unet,
...@@ -179,6 +177,35 @@ class ControlNetPipelineSDXLFastTests( ...@@ -179,6 +177,35 @@ class ControlNetPipelineSDXLFastTests(
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3) self._test_inference_batch_single_identical(expected_max_diff=2e-3)
@require_torch_gpu
def test_stable_diffusion_xl_offloads(self):
pipes = []
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe.enable_model_cpu_offload()
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe.enable_sequential_cpu_offload()
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
pipe.unet.set_default_attn_processor()
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs).images
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
def test_stable_diffusion_xl_multi_prompts(self): def test_stable_diffusion_xl_multi_prompts(self):
components = self.get_dummy_components() components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device) sd_pipe = self.pipeline_class(**components).to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment