Unverified Commit c2a38ef9 authored by Anton Lozhkov, committed by GitHub

Fix/update the LDM pipeline and tests (#1743)

* Fix/update LDM tests

* batched generators
parent 08cc36dd
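For context, the "batched generators" change lets callers pass one torch.Generator per prompt, so each image in a batch stays reproducible regardless of batch size. A minimal usage sketch, assuming the CompVis/ldm-text2im-large-256 checkpoint used in the slow tests below and a CUDA device:

import torch
from diffusers import LDMTextToImagePipeline

pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to("cuda")

prompts = ["A painting of a squirrel eating a burger"] * 2
# one generator per prompt; the list length must match the batch size
generators = [torch.Generator(device="cuda").manual_seed(seed) for seed in (0, 1)]

images = pipe(prompts, generator=generators, guidance_scale=6.0, num_inference_steps=50).images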
@@ -128,25 +128,38 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]

         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
         text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]

         # get the initial random noise unless the user supplied it
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         if latents is None:
-            if self.device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu").to(self.device)
+            rand_device = "cpu" if self.device.type == "mps" else self.device
+
+            if isinstance(generator, list):
+                latents_shape = (1,) + latents_shape[1:]
+                latents = [
+                    torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
+                    for i in range(batch_size)
+                ]
+                latents = torch.cat(latents, dim=0)
             else:
-                latents = torch.randn(
-                    latents_shape,
-                    generator=generator,
-                    device=self.device,
-                )
+                latents = torch.randn(
+                    latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
+                )
+            latents = latents.to(self.device)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
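A small, self-contained illustration of the per-sample randn pattern introduced above (shapes chosen arbitrarily for the sketch): drawing each sample from its own generator and concatenating along the batch dimension keeps every sample identical to a standalone draw with the same seed.

import torch

shape = (1, 4, 8, 8)  # (1, channels, height, width); arbitrary small shape for illustration
gens = [torch.Generator().manual_seed(s) for s in (0, 1, 2)]

# per-sample draws, concatenated along the batch dimension
batched = torch.cat([torch.randn(shape, generator=g) for g in gens], dim=0)

# the first sample matches a standalone draw with the same seed,
# so results do not depend on how many prompts are batched together
single = torch.randn(shape, generator=torch.Generator().manual_seed(0))
assert torch.equal(batched[:1], single)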
...
@@ -13,24 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import unittest

 import numpy as np
 import torch

 from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

+from ...test_pipelines_common import PipelineTesterMixin
+

 torch.backends.cuda.matmul.allow_tf32 = False


-class LDMTextToImagePipelineFastTests(unittest.TestCase):
-    @property
-    def dummy_cond_unet(self):
+class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = LDMTextToImagePipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
         torch.manual_seed(0)
-        model = UNet2DConditionModel(
+        unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
             sample_size=32,
@@ -40,25 +45,24 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
         )
-        return model
-
-    @property
-    def dummy_vae(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
         torch.manual_seed(0)
-        model = AutoencoderKL(
-            block_out_channels=[32, 64],
+        vae = AutoencoderKL(
+            block_out_channels=(32, 64),
             in_channels=3,
             out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
+            up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
             latent_channels=4,
         )
-        return model
-
-    @property
-    def dummy_text_encoder(self):
         torch.manual_seed(0)
-        config = CLIPTextConfig(
+        text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
             hidden_size=32,
@@ -69,96 +73,117 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             pad_token_id=1,
             vocab_size=1000,
         )
-        return CLIPTextModel(config)
-
-    def test_inference_text2img(self):
-        if torch_device != "cpu":
-            return
-
-        unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler()
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

-        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(
-                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
-            ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
-        ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-        image_from_tuple = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=2,
-            output_type="numpy",
-            return_dict=False,
-        )[0]
-
-        image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 16, 16, 3)
-        expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
-
-@slow
-@require_torch
-class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
-    def test_inference_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
-        ).images
-
-        image_slice = image[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-    def test_inference_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
-
-        image_slice = image[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vqvae": vae,
+            "bert": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_inference_text2img(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components()
+        pipe = LDMTextToImagePipeline(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 16, 16, 3)
+        expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+
+@slow
+@require_torch_gpu
+class LDMTextToImagePipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.51825, 0.52850, 0.52543, 0.54258, 0.52304, 0.52569, 0.54363, 0.55276, 0.56878])
+        max_diff = np.abs(expected_slice - image_slice).max()
+        assert max_diff < 1e-3
+
+
+@nightly
+@require_torch_gpu
+class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 50,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images[0]
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/ldm_text2img/ldm_large_256_ddim.npy"
+        )
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 1e-3