Unverified commit 79329715 authored by Patrick von Platen, committed by GitHub

[Upscaling] Fix batch size (#1525)

parent 720dbfc9
@@ -459,8 +459,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
         else:
             noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
         image = self.low_res_scheduler.add_noise(image, noise, noise_level)
-        image = torch.cat([image] * 2) if do_classifier_free_guidance else image
-        noise_level = torch.cat([noise_level] * 2) if do_classifier_free_guidance else noise_level
+
+        batch_multiplier = 2 if do_classifier_free_guidance else 1
+        image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
+        noise_level = torch.cat([noise_level] * image.shape[0])

         # 6. Prepare latent variables
         height, width = image.shape[2:]
...
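The hunk above is the actual fix: previously the low-res `image` and `noise_level` tensors were only duplicated for classifier-free guidance, so a prompt batch larger than one or `num_images_per_prompt > 1` left them mismatched against the text-embedding batch. Below is a minimal sketch of the new expansion arithmetic with stand-in tensors; the shapes and values are illustrative only, not taken from the pipeline:

```python
import torch

# Stand-in values; in the pipeline these come from the prompt batch,
# the call arguments, and guidance_scale > 1.0.
do_classifier_free_guidance = True
num_images_per_prompt = 2
image = torch.randn(2, 3, 64, 64)                # two low-res images, one per prompt
noise_level = torch.tensor([20], dtype=torch.long)

batch_multiplier = 2 if do_classifier_free_guidance else 1
image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
noise_level = torch.cat([noise_level] * image.shape[0])

print(image.shape)        # torch.Size([8, 3, 64, 64]): 2 prompts x 2 images x 2 (CFG)
print(noise_level.shape)  # torch.Size([8]): matches the expanded image batch
```

Keying `noise_level` off `image.shape[0]` keeps the two tensors in lockstep however the batch was expanded.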
@@ -161,6 +161,57 @@ class StableDiffusionUpscalePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

+    def test_stable_diffusion_upscale_batch(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet_upscale
+        low_res_scheduler = DDPMScheduler()
+        scheduler = DDIMScheduler(prediction_type="v_prediction")
+        vae = self.dummy_vae
+        text_encoder = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+
+        # assemble all components into the pipeline
+        sd_pipe = StableDiffusionUpscalePipeline(
+            unet=unet,
+            low_res_scheduler=low_res_scheduler,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            max_noise_level=350,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        output = sd_pipe(
+            2 * [prompt],
+            image=2 * [low_res_image],
+            guidance_scale=6.0,
+            noise_level=20,
+            num_inference_steps=2,
+            output_type="np",
+        )
+        image = output.images
+        assert image.shape[0] == 2
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            image=low_res_image,
+            generator=generator,
+            num_images_per_prompt=2,
+            guidance_scale=6.0,
+            noise_level=20,
+            num_inference_steps=2,
+            output_type="np",
+        )
+        image = output.images
+        assert image.shape[0] == 2
+
     @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
     def test_stable_diffusion_upscale_fp16(self):
         """Test that stable diffusion upscale works with fp16"""
...
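The new test exercises both batching modes the fix enables: a list of prompts paired with a list of low-res images, and `num_images_per_prompt > 1` for a single prompt and image. For context, a hedged sketch of the user-facing call path; the checkpoint id `stabilityai/stable-diffusion-x4-upscaler` and the input file are illustrative assumptions, not part of this commit:

```python
import torch
from PIL import Image
from diffusers import StableDiffusionUpscalePipeline

# Assumed checkpoint id for the x4 upscaler; not part of this commit.
pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
).to("cuda")

low_res = Image.open("low_res.png").convert("RGB").resize((128, 128))

# With this fix, num_images_per_prompt > 1 expands the low-res image
# batch to match the text embeddings instead of raising a shape mismatch.
images = pipe(
    prompt="A painting of a squirrel eating a burger",
    image=low_res,
    num_images_per_prompt=2,
    noise_level=20,
).images
assert len(images) == 2
```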