Unverified Commit 07c0fe4b authored by Will Berman, committed by GitHub

Use pipeline tests mixin for UnCLIP pipeline tests + unCLIP MPS fixes (#1908)

re: https://github.com/huggingface/diffusers/issues/1857

We relax some of the checks to deal with unCLIP reproducibility issues, mainly by checking the average pixel difference (measured within the 0-255 range) instead of the maximum pixel difference (measured within the 0-1 range). A minimal sketch of this check follows the checklist below.

- [x] add mixin to UnCLIPPipelineFastTests
- [x] add mixin to UnCLIPImageVariationPipelineFastTests
- [x] Move UnCLIPPipeline flags in mixin to base class
- [x] Small MPS fixes for F.pad and F.interpolate
- [x] Made test unCLIP model's dimensions smaller to run tests faster
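
For context, here is a minimal standalone sketch of the relaxed comparison, assuming the reference and generated images are arrays in the 0-255 range; the `mean_pixel_difference_ok` name is hypothetical, but the `assert_mean_pixel_difference` helper added to `test_pipelines_common.py` in this diff implements the same idea:

```python
import numpy as np

def mean_pixel_difference_ok(image, expected_image, threshold=10.0):
    """Pass when the images differ by fewer than `threshold` pixel values (0-255) on average."""
    # Averaging tolerates hardware-dependent noise far better than a max-difference
    # check measured in the 0-1 range, where a single outlier pixel fails the test.
    image = np.asarray(image, dtype=np.float32)
    expected_image = np.asarray(expected_image, dtype=np.float32)
    return np.abs(image - expected_image).mean() < threshold
```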
parent 1e651ca2
......@@ -208,7 +208,14 @@ class CrossAttention(nn.Module):
return attention_mask
if attention_mask.shape[-1] != target_length:
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, device=attention_mask.device)
attention_mask = torch.concat([attention_mask, padding], dim=2)
else:
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
return attention_mask
......
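For reference, a minimal standalone sketch of the MPS padding workaround shown in the hunk above, using a hypothetical `pad_attention_mask` helper and assuming the mask is extended by `target_length` zeros along its last dimension:

```python
import torch
import torch.nn.functional as F

def pad_attention_mask(attention_mask: torch.Tensor, target_length: int) -> torch.Tensor:
    if attention_mask.device.type == "mps":
        # MPS cannot pad by more than the size of the padded dimension, so build
        # the padding tensor explicitly and concatenate instead of calling F.pad.
        padding = torch.zeros(
            (*attention_mask.shape[:-1], target_length),
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        return torch.cat([attention_mask, padding], dim=-1)
    # Everywhere else, F.pad appends `target_length` zeros to the last dimension.
    return F.pad(attention_mask, (0, target_length), value=0.0)
```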
......@@ -452,7 +452,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
......
......@@ -449,7 +449,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
......
......@@ -22,7 +22,8 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
from ...pipelines import DiffusionPipeline
from ...pipelines.pipeline_utils import ImagePipelineOutput
from ...schedulers import UnCLIPScheduler
from ...utils import is_accelerate_available, logging, randn_tensor
from .text_proj import UnCLIPTextProjModel
......@@ -130,13 +131,20 @@ class UnCLIPPipeline(DiffusionPipeline):
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(device)
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
......@@ -249,7 +257,7 @@ class UnCLIPPipeline(DiffusionPipeline):
prior_num_inference_steps: int = 25,
decoder_num_inference_steps: int = 25,
super_res_num_inference_steps: int = 7,
generator: Optional[torch.Generator] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prior_latents: Optional[torch.FloatTensor] = None,
decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None,
......@@ -278,7 +286,7 @@ class UnCLIPPipeline(DiffusionPipeline):
super_res_num_inference_steps (`int`, *optional*, defaults to 7):
The number of denoising steps for super resolution. More denoising steps usually lead to a higher
quality image at the expense of slower inference.
generator (`torch.Generator`, *optional*):
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
......@@ -394,7 +402,14 @@ class UnCLIPPipeline(DiffusionPipeline):
do_classifier_free_guidance=do_classifier_free_guidance,
)
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
if device.type == "mps":
# HACK: MPS: There is a panic when padding bool tensors,
# so cast to int tensor for the pad and back to bool afterwards
text_mask = text_mask.type(torch.int)
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
decoder_text_mask = decoder_text_mask.type(torch.bool)
else:
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
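For reference, a minimal standalone sketch of the bool-mask padding workaround from the hunk above, using a hypothetical `pad_bool_mask` helper; the ones prepended on the left stand in for the extra CLIP context tokens:

```python
import torch
import torch.nn.functional as F

def pad_bool_mask(text_mask: torch.Tensor, num_extra_tokens: int) -> torch.Tensor:
    if text_mask.device.type == "mps":
        # Padding bool tensors panics on MPS, so round-trip through an int tensor.
        padded = F.pad(text_mask.type(torch.int), (num_extra_tokens, 0), value=1)
        return padded.type(torch.bool)
    # On other devices the bool mask can be padded directly with True values.
    return F.pad(text_mask, (num_extra_tokens, 0), value=True)
```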
......@@ -465,13 +480,17 @@ class UnCLIPPipeline(DiffusionPipeline):
self.super_res_scheduler,
)
interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters:
interpolate_antialias["antialias"] = True
if device.type == "mps":
# MPS does not support many interpolations
image_upscaled = F.interpolate(image_small, size=[height, width])
else:
interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters:
interpolate_antialias["antialias"] = True
image_upscaled = F.interpolate(
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
)
image_upscaled = F.interpolate(
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
)
for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
# no classifier free guidance
......
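For reference, a minimal standalone sketch of the super-resolution upscaling change above, using a hypothetical `upscale` helper and assuming `image_small` is an NCHW float tensor; `antialias` is feature-detected because it only exists in newer torch versions:

```python
import inspect
import torch
import torch.nn.functional as F

def upscale(image_small: torch.Tensor, height: int, width: int) -> torch.Tensor:
    if image_small.device.type == "mps":
        # MPS supports only a limited set of interpolation modes, so fall back
        # to the default (nearest-neighbor) resize.
        return F.interpolate(image_small, size=[height, width])
    # Pass antialias=True only when this torch version exposes the argument.
    kwargs = {}
    if "antialias" in inspect.signature(F.interpolate).parameters:
        kwargs["antialias"] = True
    return F.interpolate(
        image_small, size=[height, width], mode="bicubic", align_corners=False, **kwargs
    )
```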
......@@ -328,7 +328,14 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
do_classifier_free_guidance=do_classifier_free_guidance,
)
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
if device.type == "mps":
# HACK: MPS: There is a panic when padding bool tensors,
# so cast to int tensor for the pad and back to bool afterwards
text_mask = text_mask.type(torch.int)
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
decoder_text_mask = decoder_text_mask.type(torch.bool)
else:
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
......@@ -401,13 +408,17 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
self.super_res_scheduler,
)
interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters:
interpolate_antialias["antialias"] = True
if device.type == "mps":
# MPS does not support many interpolations
image_upscaled = F.interpolate(image_small, size=[height, width])
else:
interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters:
interpolate_antialias["antialias"] = True
image_upscaled = F.interpolate(
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
)
image_upscaled = F.interpolate(
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
)
for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
# no classifier free guidance
......
......@@ -219,7 +219,6 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
returning a tuple, the first element is the sample tensor.
"""
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range":
......
......@@ -25,16 +25,24 @@ from diffusers.utils import load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
torch.backends.cuda.matmul.allow_tf32 = False
class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = UnCLIPPipeline
class UnCLIPPipelineFastTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
required_optional_params = [
"generator",
"return_dict",
"prior_num_inference_steps",
"decoder_num_inference_steps",
"super_res_num_inference_steps",
]
num_inference_steps_args = [
"prior_num_inference_steps",
"decoder_num_inference_steps",
"super_res_num_inference_steps",
]
@property
def text_embedder_hidden_size(self):
......@@ -110,7 +118,7 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
torch.manual_seed(0)
model_kwargs = {
"sample_size": 64,
"sample_size": 32,
# RGB in channels
"in_channels": 3,
# Out channels is double in channels because predicts mean and variance
......@@ -132,7 +140,7 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
@property
def dummy_super_res_kwargs(self):
return {
"sample_size": 128,
"sample_size": 64,
"layers_per_block": 1,
"down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
"up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
......@@ -156,9 +164,7 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
model = UNet2DModel(**self.dummy_super_res_kwargs)
return model
def test_unclip(self):
device = "cpu"
def get_dummy_components(self):
prior = self.dummy_prior
decoder = self.dummy_decoder
text_proj = self.dummy_text_proj
......@@ -186,62 +192,70 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
num_train_timesteps=1000,
)
pipe = UnCLIPPipeline(
prior=prior,
decoder=decoder,
text_proj=text_proj,
text_encoder=text_encoder,
tokenizer=tokenizer,
super_res_first=super_res_first,
super_res_last=super_res_last,
prior_scheduler=prior_scheduler,
decoder_scheduler=decoder_scheduler,
super_res_scheduler=super_res_scheduler,
)
components = {
"prior": prior,
"decoder": decoder,
"text_proj": text_proj,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"super_res_first": super_res_first,
"super_res_last": super_res_last,
"prior_scheduler": prior_scheduler,
"decoder_scheduler": decoder_scheduler,
"super_res_scheduler": super_res_scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "horse",
"generator": generator,
"prior_num_inference_steps": 2,
"decoder_num_inference_steps": 2,
"super_res_num_inference_steps": 2,
"output_type": "numpy",
}
return inputs
def test_unclip(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
prompt = "horse"
generator = torch.Generator(device=device).manual_seed(0)
output = pipe(
[prompt],
generator=generator,
prior_num_inference_steps=2,
decoder_num_inference_steps=2,
super_res_num_inference_steps=2,
output_type="np",
)
output = pipe(**self.get_dummy_inputs(device))
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = pipe(
[prompt],
generator=generator,
prior_num_inference_steps=2,
decoder_num_inference_steps=2,
super_res_num_inference_steps=2,
output_type="np",
**self.get_dummy_inputs(device),
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array(
[
0.0011,
0.0002,
0.9962,
0.9940,
0.0002,
0.9997,
0.0003,
0.9987,
0.9989,
0.9988,
0.0028,
0.9997,
0.9984,
0.9965,
0.0029,
0.9986,
0.0025,
]
)
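For reference, a minimal standalone sketch of the device-aware generator construction used in `get_dummy_inputs` above, with a hypothetical `make_generator` helper, assuming the torch version in use cannot create a `torch.Generator` on the `mps` device:

```python
import torch

def make_generator(device: str, seed: int = 0) -> torch.Generator:
    if str(device).startswith("mps"):
        # MPS generators are not supported here, so seed the global RNG and
        # reuse the default generator that torch.manual_seed returns.
        return torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)
```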
......@@ -254,47 +268,17 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
class DummyScheduler:
init_noise_sigma = 1
prior = self.dummy_prior
decoder = self.dummy_decoder
text_proj = self.dummy_text_proj
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
super_res_first = self.dummy_super_res_first
super_res_last = self.dummy_super_res_last
prior_scheduler = UnCLIPScheduler(
variance_type="fixed_small_log",
prediction_type="sample",
num_train_timesteps=1000,
clip_sample_range=5.0,
)
decoder_scheduler = UnCLIPScheduler(
variance_type="learned_range",
prediction_type="epsilon",
num_train_timesteps=1000,
)
super_res_scheduler = UnCLIPScheduler(
variance_type="fixed_small_log",
prediction_type="epsilon",
num_train_timesteps=1000,
)
components = self.get_dummy_components()
pipe = UnCLIPPipeline(
prior=prior,
decoder=decoder,
text_proj=text_proj,
text_encoder=text_encoder,
tokenizer=tokenizer,
super_res_first=super_res_first,
super_res_last=super_res_last,
prior_scheduler=prior_scheduler,
decoder_scheduler=decoder_scheduler,
super_res_scheduler=super_res_scheduler,
)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
prior = components["prior"]
decoder = components["decoder"]
super_res_first = components["super_res_first"]
tokenizer = components["tokenizer"]
text_encoder = components["text_encoder"]
generator = torch.Generator(device=device).manual_seed(0)
dtype = prior.dtype
batch_size = 1
......@@ -362,6 +346,45 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
# make sure passing text embeddings manually is identical
assert np.abs(image - image_from_text).max() < 1e-4
# Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
# because UnCLIP GPU undeterminism requires a looser check.
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_attention_slicing_forward_pass(self):
test_max_difference = torch_device == "cpu"
self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
# Overriding PipelineTesterMixin::test_inference_batch_single_identical
# because UnCLIP undeterminism requires a looser check.
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_inference_batch_single_identical(self):
test_max_difference = torch_device == "cpu"
relax_max_difference = True
self._test_inference_batch_single_identical(
test_max_difference=test_max_difference, relax_max_difference=relax_max_difference
)
def test_inference_batch_consistent(self):
if torch_device == "mps":
# TODO: MPS errors with larger batch sizes
batch_sizes = [2, 3]
self._test_inference_batch_consistent(batch_sizes=batch_sizes)
else:
self._test_inference_batch_consistent()
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_dict_tuple_outputs_equivalent(self):
return super().test_dict_tuple_outputs_equivalent()
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_save_load_local(self):
return super().test_save_load_local()
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_save_load_optional_components(self):
return super().test_save_load_optional_components()
@nightly
class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
......@@ -420,16 +443,12 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
output_type="np",
)
image = np.asarray(pipeline.numpy_to_pil(output.images)[0], dtype=np.float32)
expected_image = np.asarray(pipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
# Karlo is extremely likely to strongly deviate depending on which hardware is used
# Here we just check that the image doesn't deviate more than 10 pixels from the reference image on average
avg_diff = np.abs(image - expected_image).mean()
image = output.images[0]
assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
assert image.shape == (256, 256, 3)
assert_mean_pixel_difference(image, expected_image)
def test_unclip_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
......
......@@ -39,16 +39,22 @@ from transformers import (
CLIPVisionModelWithProjection,
)
from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
torch.backends.cuda.matmul.allow_tf32 = False
class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = UnCLIPImageVariationPipeline
class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
required_optional_params = [
"generator",
"return_dict",
"decoder_num_inference_steps",
"super_res_num_inference_steps",
]
num_inference_steps_args = [
"decoder_num_inference_steps",
"super_res_num_inference_steps",
]
@property
def text_embedder_hidden_size(self):
......@@ -124,7 +130,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
torch.manual_seed(0)
model_kwargs = {
"sample_size": 64,
"sample_size": 32,
# RGB in channels
"in_channels": 3,
# Out channels is double in channels because predicts mean and variance
......@@ -146,7 +152,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
@property
def dummy_super_res_kwargs(self):
return {
"sample_size": 128,
"sample_size": 64,
"layers_per_block": 1,
"down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
"up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
......@@ -170,7 +176,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
model = UNet2DModel(**self.dummy_super_res_kwargs)
return model
def get_pipeline(self, device):
def get_dummy_components(self):
decoder = self.dummy_decoder
text_proj = self.dummy_text_proj
text_encoder = self.dummy_text_encoder
......@@ -194,27 +200,25 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
image_encoder = self.dummy_image_encoder
pipe = UnCLIPImageVariationPipeline(
decoder=decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_proj=text_proj,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
super_res_first=super_res_first,
super_res_last=super_res_last,
decoder_scheduler=decoder_scheduler,
super_res_scheduler=super_res_scheduler,
)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
return pipe
return {
"decoder": decoder,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_proj": text_proj,
"feature_extractor": feature_extractor,
"image_encoder": image_encoder,
"super_res_first": super_res_first,
"super_res_last": super_res_last,
"decoder_scheduler": decoder_scheduler,
"super_res_scheduler": super_res_scheduler,
}
def get_pipeline_inputs(self, device, seed, pil_image=False):
def get_dummy_inputs(self, device, seed=0, pil_image=True):
input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
generator = torch.Generator(device=device).manual_seed(seed)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
if pil_image:
input_image = input_image * 0.5 + 0.5
......@@ -232,16 +236,20 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
def test_unclip_image_variation_input_tensor(self):
device = "cpu"
seed = 0
pipe = self.get_pipeline(device)
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
pipeline_inputs = self.get_pipeline_inputs(device, seed)
pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
output = pipe(**pipeline_inputs)
image = output.images
tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed)
tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
image_from_tuple = pipe(
**tuple_pipeline_inputs,
......@@ -251,19 +259,19 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array(
[
0.9988,
0.9997,
0.9944,
0.0003,
0.0003,
0.9974,
0.0003,
0.0004,
0.9931,
0.0002,
0.9997,
0.9997,
0.9969,
0.0023,
0.9997,
0.9969,
0.9970,
]
)
......@@ -272,16 +280,20 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
def test_unclip_image_variation_input_image(self):
device = "cpu"
seed = 0
pipe = self.get_pipeline(device)
components = self.get_dummy_components()
pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
output = pipe(**pipeline_inputs)
image = output.images
tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
image_from_tuple = pipe(
**tuple_pipeline_inputs,
......@@ -291,32 +303,24 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array(
[
0.9988,
0.9997,
0.9944,
0.0003,
0.0003,
0.9974,
0.0003,
0.0004,
0.9931,
]
)
expected_slice = np.array([0.9997, 0.0003, 0.9997, 0.9997, 0.9970, 0.0024, 0.9997, 0.9971, 0.9971])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_unclip_image_variation_input_list_images(self):
device = "cpu"
seed = 0
pipe = self.get_pipeline(device)
components = self.get_dummy_components()
pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
pipeline_inputs["image"] = [
pipeline_inputs["image"],
pipeline_inputs["image"],
......@@ -325,7 +329,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
output = pipe(**pipeline_inputs)
image = output.images
tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
tuple_pipeline_inputs["image"] = [
tuple_pipeline_inputs["image"],
tuple_pipeline_inputs["image"],
......@@ -339,19 +343,19 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (2, 128, 128, 3)
assert image.shape == (2, 64, 64, 3)
expected_slice = np.array(
[
0.9997,
0.9997,
0.0003,
0.0003,
0.9950,
0.0003,
0.9993,
0.9957,
0.0004,
0.9989,
0.0008,
0.0021,
0.9960,
0.0018,
0.0014,
0.0002,
0.9933,
]
)
......@@ -360,11 +364,15 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
def test_unclip_image_variation_input_num_images_per_prompt(self):
device = "cpu"
seed = 0
pipe = self.get_pipeline(device)
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
pipe.set_progress_bar_config(disable=None)
pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
pipeline_inputs["image"] = [
pipeline_inputs["image"],
pipeline_inputs["image"],
......@@ -373,7 +381,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
output = pipe(**pipeline_inputs, num_images_per_prompt=2)
image = output.images
tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
tuple_pipeline_inputs["image"] = [
tuple_pipeline_inputs["image"],
tuple_pipeline_inputs["image"],
......@@ -388,18 +396,18 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (4, 128, 128, 3)
assert image.shape == (4, 64, 64, 3)
expected_slice = np.array(
[
0.9997,
0.9997,
0.0008,
0.9952,
0.9980,
0.9997,
0.9961,
0.0023,
0.0029,
0.9997,
0.9985,
0.9997,
0.0010,
0.9995,
]
)
......@@ -409,12 +417,16 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
def test_unclip_passed_image_embed(self):
device = torch.device("cpu")
seed = 0
class DummyScheduler:
init_noise_sigma = 1
pipe = self.get_pipeline(device)
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device=device).manual_seed(0)
dtype = pipe.decoder.dtype
......@@ -435,13 +447,13 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
)
pipeline_inputs = self.get_pipeline_inputs(device, seed)
pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
img_out_1 = pipe(
**pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
).images
pipeline_inputs = self.get_pipeline_inputs(device, seed)
pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
# Don't pass image, instead pass embedding
image = pipeline_inputs.pop("image")
image_embeddings = pipe.image_encoder(image).image_embeds
......@@ -456,6 +468,45 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
# make sure passing text embeddings manually is identical
assert np.abs(img_out_1 - img_out_2).max() < 1e-4
# Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
# because UnCLIP GPU undeterminism requires a looser check.
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_attention_slicing_forward_pass(self):
test_max_difference = torch_device == "cpu"
self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
# Overriding PipelineTesterMixin::test_inference_batch_single_identical
# because UnCLIP undeterminism requires a looser check.
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_inference_batch_single_identical(self):
test_max_difference = torch_device == "cpu"
relax_max_difference = True
self._test_inference_batch_single_identical(
test_max_difference=test_max_difference, relax_max_difference=relax_max_difference
)
def test_inference_batch_consistent(self):
if torch_device == "mps":
# TODO: MPS errors with larger batch sizes
batch_sizes = [2, 3]
self._test_inference_batch_consistent(batch_sizes=batch_sizes)
else:
self._test_inference_batch_consistent()
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_dict_tuple_outputs_equivalent(self):
return super().test_dict_tuple_outputs_equivalent()
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_save_load_local(self):
return super().test_save_load_local()
@unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
def test_save_load_optional_components(self):
return super().test_save_load_optional_components()
@slow
@require_torch_gpu
......@@ -488,12 +539,8 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
output_type="np",
)
image = np.asarray(pipeline.numpy_to_pil(output.images)[0], dtype=np.float32)
expected_image = np.asarray(pipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
# Karlo is extremely likely to strongly deviate depending on which hardware is used
# Here we just check that the image doesn't deviate more than 10 pixels from the reference image on average
avg_diff = np.abs(image - expected_image).mean()
image = output.images[0]
assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
assert image.shape == (256, 256, 3)
assert_mean_pixel_difference(image, expected_image)
......@@ -28,9 +28,6 @@ from diffusers.utils.testing_utils import require_torch, torch_device
torch.backends.cuda.matmul.allow_tf32 = False
ALLOWED_REQUIRED_ARGS = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
@require_torch
class PipelineTesterMixin:
"""
......@@ -39,6 +36,10 @@ class PipelineTesterMixin:
equivalence of dict and tuple outputs, etc.
"""
allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
required_optional_params = ["generator", "num_inference_steps", "return_dict"]
num_inference_steps_args = ["num_inference_steps"]
# set these parameters to False in the child class if the pipeline does not support the corresponding functionality
test_attention_slicing = True
test_cpu_offload = True
......@@ -120,15 +121,17 @@ class PipelineTesterMixin:
if param == "kwargs":
# kwargs can be added if arguments of pipeline call function are deprecated
continue
assert param in ALLOWED_REQUIRED_ARGS
assert param in self.allowed_required_args
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
required_optional_params = ["generator", "num_inference_steps", "return_dict"]
for param in required_optional_params:
for param in self.required_optional_params:
assert param in optional_parameters
def test_inference_batch_consistent(self):
self._test_inference_batch_consistent()
def _test_inference_batch_consistent(self, batch_sizes=[2, 4, 13]):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
......@@ -140,10 +143,10 @@ class PipelineTesterMixin:
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
for batch_size in [2, 4, 13]:
for batch_size in batch_sizes:
batched_inputs = {}
for name, value in inputs.items():
if name in ALLOWED_REQUIRED_ARGS:
if name in self.allowed_required_args:
# prompt is string
if name == "prompt":
len_prompt = len(value)
......@@ -160,7 +163,9 @@ class PipelineTesterMixin:
else:
batched_inputs[name] = value
batched_inputs["num_inference_steps"] = inputs["num_inference_steps"]
for arg in self.num_inference_steps_args:
batched_inputs[arg] = inputs[arg]
batched_inputs["output_type"] = None
if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
......@@ -182,12 +187,26 @@ class PipelineTesterMixin:
logger.setLevel(level=diffusers.logging.WARNING)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical()
def _test_inference_batch_single_identical(
self, test_max_difference=None, test_mean_pixel_difference=None, relax_max_difference=False
):
if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]:
# RePaint can hardly be made deterministic since the scheduler is currently always
# indeterministic
# CycleDiffusion is also slightly undeterministic
return
if test_max_difference is None:
# TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
# make sure that batched and non-batched is identical
test_max_difference = torch_device != "mps"
if test_mean_pixel_difference is None:
# TODO same as above
test_mean_pixel_difference = torch_device != "mps"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
......@@ -202,7 +221,7 @@ class PipelineTesterMixin:
batched_inputs = {}
batch_size = 3
for name, value in inputs.items():
if name in ALLOWED_REQUIRED_ARGS:
if name in self.allowed_required_args:
# prompt is string
if name == "prompt":
len_prompt = len(value)
......@@ -221,7 +240,8 @@ class PipelineTesterMixin:
else:
batched_inputs[name] = value
batched_inputs["num_inference_steps"] = inputs["num_inference_steps"]
for arg in self.num_inference_steps_args:
batched_inputs[arg] = inputs[arg]
if self.pipeline_class.__name__ != "DanceDiffusionPipeline":
batched_inputs["output_type"] = "np"
......@@ -234,10 +254,19 @@ class PipelineTesterMixin:
output = pipe(**inputs)
logger.setLevel(level=diffusers.logging.WARNING)
if torch_device != "mps":
# TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
# make sure that batched and non-batched is identical
assert np.abs(output_batch[0][0] - output[0][0]).max() < 1e-4
if test_max_difference:
if relax_max_difference:
# Taking the median of the largest <n> differences
# is resilient to outliers
diff = np.abs(output_batch[0][0] - output[0][0])
diff.sort()
max_diff = np.median(diff[-5:])
else:
max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
assert max_diff < 1e-4
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_batch[0][0], output[0][0])
def test_dict_tuple_outputs_equivalent(self):
if torch_device == "mps" and self.pipeline_class in (
......@@ -278,7 +307,9 @@ class PipelineTesterMixin:
times = []
for num_steps in [9, 6, 3]:
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = num_steps
for arg in self.num_inference_steps_args:
inputs[arg] = num_steps
start_time = time.time()
output = pipe(**inputs)[0]
......@@ -419,6 +450,9 @@ class PipelineTesterMixin:
self.assertTrue(np.isnan(output_cuda).sum() == 0)
def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass()
def _test_attention_slicing_forward_pass(self, test_max_difference=True):
if not self.test_attention_slicing:
return
......@@ -448,8 +482,11 @@ class PipelineTesterMixin:
inputs = self.get_dummy_inputs(torch_device)
output_with_slicing = pipe(**inputs)[0]
max_diff = np.abs(output_with_slicing - output_without_slicing).max()
self.assertLess(max_diff, 1e-3, "Attention slicing should not affect the inference results")
if test_max_difference:
max_diff = np.abs(output_with_slicing - output_without_slicing).max()
self.assertLess(max_diff, 1e-3, "Attention slicing should not affect the inference results")
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available(),
......@@ -518,3 +555,13 @@ class PipelineTesterMixin:
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
def assert_mean_pixel_difference(image, expected_image):
image = np.asarray(DiffusionPipeline.numpy_to_pil(image)[0], dtype=np.float32)
expected_image = np.asarray(DiffusionPipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
avg_diff = np.abs(image - expected_image).mean()
assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"