Unverified Commit 18b018c8 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[SDXL Refiner] Fix refiner forward pass for batched input (#4327)

* fix_batch_xl

* Fix other pipelines as well

* up

* up

* Update tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py

* sort

* up

* Finish it all up Co-authored-by: Bagheera <bghira@users.github.com>

* Co-authored-by: Bagheera bghira@users.github.com

* Co-authored-by: Bagheera <bghira@users.github.com>

* Finish it all up Co-authored-by: Bagheera <bghira@users.github.com>
parent 54fab2cd
...@@ -906,15 +906,17 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -906,15 +906,17 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
negative_aesthetic_score, negative_aesthetic_score,
dtype=prompt_embeds.dtype, dtype=prompt_embeds.dtype,
) )
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if do_classifier_free_guidance: if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device) prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device) add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) add_time_ids = add_time_ids.to(device)
# 9. Denoising loop # 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
......
...@@ -1168,15 +1168,17 @@ class StableDiffusionXLInpaintPipeline( ...@@ -1168,15 +1168,17 @@ class StableDiffusionXLInpaintPipeline(
negative_aesthetic_score, negative_aesthetic_score,
dtype=prompt_embeds.dtype, dtype=prompt_embeds.dtype,
) )
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if do_classifier_free_guidance: if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device) prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device) add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) add_time_ids = add_time_ids.to(device)
# 11. Denoising loop # 11. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
......
...@@ -811,6 +811,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile ...@@ -811,6 +811,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
negative_aesthetic_score, negative_aesthetic_score,
dtype=prompt_embeds.dtype, dtype=prompt_embeds.dtype,
) )
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
original_prompt_embeds_len = len(prompt_embeds) original_prompt_embeds_len = len(prompt_embeds)
original_add_text_embeds_len = len(add_text_embeds) original_add_text_embeds_len = len(add_text_embeds)
...@@ -819,6 +820,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile ...@@ -819,6 +820,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
if do_classifier_free_guidance: if do_classifier_free_guidance:
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0) add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0) add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)
# Make dimensions consistent # Make dimensions consistent
...@@ -828,7 +830,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile ...@@ -828,7 +830,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
prompt_embeds = prompt_embeds.to(device).to(torch.float32) prompt_embeds = prompt_embeds.to(device).to(torch.float32)
add_text_embeds = add_text_embeds.to(device).to(torch.float32) add_text_embeds = add_text_embeds.to(device).to(torch.float32)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) add_time_ids = add_time_ids.to(device)
# 11. Denoising loop # 11. Denoising loop
self.unet = self.unet.to(torch.float32) self.unet = self.unet.to(torch.float32)
......
...@@ -64,7 +64,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -64,7 +64,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
addition_embed_type="text_time", addition_embed_type="text_time",
addition_time_embed_dim=8, addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2), transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32 projection_class_embeddings_input_dim=72, # 5 * 8 + 32
cross_attention_dim=64 if not skip_first_text_encoder else 32, cross_attention_dim=64 if not skip_first_text_encoder else 32,
) )
scheduler = EulerDiscreteScheduler( scheduler = EulerDiscreteScheduler(
...@@ -113,9 +113,18 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -113,9 +113,18 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
"tokenizer": tokenizer if not skip_first_text_encoder else None, "tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2, "text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2, "tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
} }
return components return components
def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def get_dummy_inputs(self, device, seed=0): def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5 image = image / 2 + 0.5
...@@ -147,7 +156,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -147,7 +156,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert image.shape == (1, 32, 32, 3) assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4656, 0.4840, 0.4439, 0.6698, 0.5574, 0.4524, 0.5799, 0.5943, 0.5165]) expected_slice = np.array([0.4664, 0.4886, 0.4403, 0.6902, 0.5592, 0.4534, 0.5931, 0.5951, 0.5224])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
...@@ -165,7 +174,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -165,7 +174,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert image.shape == (1, 32, 32, 3) assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4676, 0.4865, 0.4335, 0.6715, 0.5578, 0.4497, 0.5847, 0.5967, 0.5198]) expected_slice = np.array([0.4578, 0.4981, 0.4301, 0.6454, 0.5588, 0.4442, 0.5678, 0.5940, 0.5176])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
......
...@@ -66,7 +66,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -66,7 +66,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
addition_embed_type="text_time", addition_embed_type="text_time",
addition_time_embed_dim=8, addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2), transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32 projection_class_embeddings_input_dim=72, # 5 * 8 + 32
cross_attention_dim=64 if not skip_first_text_encoder else 32, cross_attention_dim=64 if not skip_first_text_encoder else 32,
) )
scheduler = EulerDiscreteScheduler( scheduler = EulerDiscreteScheduler(
...@@ -115,6 +115,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -115,6 +115,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
"tokenizer": tokenizer if not skip_first_text_encoder else None, "tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2, "text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2, "tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
} }
return components return components
...@@ -142,6 +143,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -142,6 +143,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
} }
return inputs return inputs
def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def test_stable_diffusion_xl_inpaint_euler(self): def test_stable_diffusion_xl_inpaint_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components() components = self.get_dummy_components()
...@@ -155,7 +164,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -155,7 +164,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert image.shape == (1, 64, 64, 3) assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319]) expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
...@@ -250,10 +259,9 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel ...@@ -250,10 +259,9 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
image = sd_pipe(**inputs).images image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1] image_slice = image[0, -3:, -3:, -1]
print(torch.from_numpy(image_slice).flatten())
assert image.shape == (1, 64, 64, 3) assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.9106, 0.6563, 0.6766, 0.6537, 0.6709, 0.7367, 0.6537, 0.5937, 0.5418]) expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
......
...@@ -68,7 +68,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests( ...@@ -68,7 +68,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
addition_embed_type="text_time", addition_embed_type="text_time",
addition_time_embed_dim=8, addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2), transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32 projection_class_embeddings_input_dim=72, # 5 * 8 + 32
cross_attention_dim=64, cross_attention_dim=64,
) )
...@@ -118,8 +118,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests( ...@@ -118,8 +118,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
"tokenizer": tokenizer, "tokenizer": tokenizer,
"text_encoder_2": text_encoder_2, "text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2, "tokenizer_2": tokenizer_2,
# "safety_checker": None, "requires_aesthetics_score": True,
# "feature_extractor": None,
} }
return components return components
...@@ -141,6 +140,14 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests( ...@@ -141,6 +140,14 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
} }
return inputs return inputs
def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3) super().test_inference_batch_single_identical(expected_max_diff=3e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment