Unverified Commit 07c0fe4b authored by Will Berman, committed by GitHub

Use pipeline tests mixin for UnCLIP pipeline tests + unCLIP MPS fixes (#1908)

re: https://github.com/huggingface/diffusers/issues/1857

We relax some of the checks to deal with unCLIP reproducibility issues, mainly by checking the average pixel difference (measured within 0-255) instead of the max pixel difference (measured within 0-1); a sketch of the idea follows.
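Assuming float images in `[0, 1]` (the name `mean_pixel_difference_ok` and its `threshold` default are purely illustrative; the helper this PR actually adds is `assert_mean_pixel_difference` at the bottom of `test_pipelines_common.py`):

```python
import numpy as np

def mean_pixel_difference_ok(image, expected_image, threshold=10):
    # Compare in 0-255 space: a small average deviation is accepted
    # even when a handful of individual pixels differ a lot.
    image = np.asarray(image, dtype=np.float32) * 255
    expected_image = np.asarray(expected_image, dtype=np.float32) * 255
    return np.abs(image - expected_image).mean() < threshold
```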

- [x] add mixin to UnCLIPPipelineFastTests
- [x] add mixin to UnCLIPImageVariationPipelineFastTests
- [x] Move UnCLIPPipeline flags in mixin to base class
- [x] Small MPS fixes for `F.pad` and `F.interpolate` (see the sketch after this list)
- [x] Made test unCLIP model's dimensions smaller to run tests faster
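
A minimal sketch of the `F.pad` workaround, runnable on CPU with illustrative shapes: MPS cannot pad a tensor by more than its own size along the padded dimension, so the fix builds the zero padding manually and concatenates it. The related bool-mask hack in the unCLIP pipelines casts the mask to int before padding and back to bool afterwards, because padding bool tensors panics on MPS.

```python
import torch
import torch.nn.functional as F

attention_mask = torch.rand(2, 4, 8)  # (batch, heads, current_length)
target_length = 16

# F.pad by more than the input's own length panics on MPS...
padded = F.pad(attention_mask, (0, target_length), value=0.0)

# ...so construct the padding tensor manually and concatenate instead
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
padding = torch.zeros(padding_shape, device=attention_mask.device)
manual = torch.concat([attention_mask, padding], dim=2)

assert torch.equal(padded, manual)  # identical result, shape (2, 4, 24)
```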
parent 1e651ca2
@@ -208,6 +208,13 @@ class CrossAttention(nn.Module):
             return attention_mask

         if attention_mask.shape[-1] != target_length:
-            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+            if attention_mask.device.type == "mps":
+                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+                padding = torch.zeros(padding_shape, device=attention_mask.device)
+                attention_mask = torch.concat([attention_mask, padding], dim=2)
+            else:
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
         attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
         return attention_mask
...
@@ -452,7 +452,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`torch.Generator`, *optional*):
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
             latents (`torch.FloatTensor`, *optional*):
...
@@ -449,7 +449,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
             eta (`float`, *optional*, defaults to 0.0):
                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                 [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`torch.Generator`, *optional*):
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
             latents (`torch.FloatTensor`, *optional*):
...
@@ -22,7 +22,8 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput

 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
-from ...pipelines import DiffusionPipeline, ImagePipelineOutput
+from ...pipelines import DiffusionPipeline
+from ...pipelines.pipeline_utils import ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
 from ...utils import is_accelerate_available, logging, randn_tensor
 from .text_proj import UnCLIPTextProjModel

@@ -130,13 +131,20 @@ class UnCLIPPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
         text_mask = text_inputs.attention_mask.bool().to(device)

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+            text_input_ids, untruncated_ids
+        ):
+            removed_text = self.tokenizer.batch_decode(
+                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+            )
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
@@ -249,7 +257,7 @@ class UnCLIPPipeline(DiffusionPipeline):
         prior_num_inference_steps: int = 25,
         decoder_num_inference_steps: int = 25,
         super_res_num_inference_steps: int = 7,
-        generator: Optional[torch.Generator] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         prior_latents: Optional[torch.FloatTensor] = None,
         decoder_latents: Optional[torch.FloatTensor] = None,
         super_res_latents: Optional[torch.FloatTensor] = None,

@@ -278,7 +286,7 @@ class UnCLIPPipeline(DiffusionPipeline):
             super_res_num_inference_steps (`int`, *optional*, defaults to 7):
                 The number of denoising steps for super resolution. More denoising steps usually lead to a higher
                 quality image at the expense of slower inference.
-            generator (`torch.Generator`, *optional*):
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
             prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):

@@ -394,7 +402,14 @@ class UnCLIPPipeline(DiffusionPipeline):
             do_classifier_free_guidance=do_classifier_free_guidance,
         )

-        decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+        if device.type == "mps":
+            # HACK: MPS: There is a panic when padding bool tensors,
+            # so cast to int tensor for the pad and back to bool afterwards
+            text_mask = text_mask.type(torch.int)
+            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+            decoder_text_mask = decoder_text_mask.type(torch.bool)
+        else:
+            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)

         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

@@ -465,6 +480,10 @@ class UnCLIPPipeline(DiffusionPipeline):
             self.super_res_scheduler,
         )

-        interpolate_antialias = {}
-        if "antialias" in inspect.signature(F.interpolate).parameters:
-            interpolate_antialias["antialias"] = True
+        if device.type == "mps":
+            # MPS does not support many interpolations
+            image_upscaled = F.interpolate(image_small, size=[height, width])
+        else:
+            interpolate_antialias = {}
+            if "antialias" in inspect.signature(F.interpolate).parameters:
+                interpolate_antialias["antialias"] = True
...
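On the non-MPS path above, `antialias` is passed to `F.interpolate` only when the installed torch exposes it, probed via `inspect.signature` rather than a version pin. The pattern in isolation (the `mode`/`align_corners` arguments are assumed from the surrounding pipeline code, which this view truncates):

```python
import inspect

import torch
import torch.nn.functional as F

# Probe the runtime signature: `antialias` only exists in newer torch releases.
interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters:
    interpolate_antialias["antialias"] = True

image_small = torch.rand(1, 3, 32, 32)
image_upscaled = F.interpolate(
    image_small, size=[64, 64], mode="bicubic", align_corners=False, **interpolate_antialias
)
```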
@@ -328,7 +328,14 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
             do_classifier_free_guidance=do_classifier_free_guidance,
         )

-        decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+        if device.type == "mps":
+            # HACK: MPS: There is a panic when padding bool tensors,
+            # so cast to int tensor for the pad and back to bool afterwards
+            text_mask = text_mask.type(torch.int)
+            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+            decoder_text_mask = decoder_text_mask.type(torch.bool)
+        else:
+            decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)

         self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

@@ -401,6 +408,10 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
             self.super_res_scheduler,
         )

-        interpolate_antialias = {}
-        if "antialias" in inspect.signature(F.interpolate).parameters:
-            interpolate_antialias["antialias"] = True
+        if device.type == "mps":
+            # MPS does not support many interpolations
+            image_upscaled = F.interpolate(image_small, size=[height, width])
+        else:
+            interpolate_antialias = {}
+            if "antialias" in inspect.signature(F.interpolate).parameters:
+                interpolate_antialias["antialias"] = True
...
@@ -219,7 +219,6 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
             returning a tuple, the first element is the sample tensor.
         """
         t = timestep
-
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range":
...
@@ -25,16 +25,24 @@ from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer

+from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
+
 torch.backends.cuda.matmul.allow_tf32 = False


-class UnCLIPPipelineFastTests(unittest.TestCase):
-    def tearDown(self):
-        # clean up the VRAM after each test
-        super().tearDown()
-        gc.collect()
-        torch.cuda.empty_cache()
+class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = UnCLIPPipeline
+
+    required_optional_params = [
+        "generator",
+        "return_dict",
+        "prior_num_inference_steps",
+        "decoder_num_inference_steps",
+        "super_res_num_inference_steps",
+    ]
+    num_inference_steps_args = [
+        "prior_num_inference_steps",
+        "decoder_num_inference_steps",
+        "super_res_num_inference_steps",
+    ]

     @property
     def text_embedder_hidden_size(self):
@@ -110,7 +118,7 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
         torch.manual_seed(0)

         model_kwargs = {
-            "sample_size": 64,
+            "sample_size": 32,
             # RGB in channels
             "in_channels": 3,
             # Out channels is double in channels because predicts mean and variance
@@ -132,7 +140,7 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
     @property
     def dummy_super_res_kwargs(self):
         return {
-            "sample_size": 128,
+            "sample_size": 64,
             "layers_per_block": 1,
             "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
             "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
@@ -156,9 +164,7 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
         model = UNet2DModel(**self.dummy_super_res_kwargs)
         return model

-    def test_unclip(self):
-        device = "cpu"
-
+    def get_dummy_components(self):
         prior = self.dummy_prior
         decoder = self.dummy_decoder
         text_proj = self.dummy_text_proj
@@ -186,62 +192,70 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
             num_train_timesteps=1000,
         )

-        pipe = UnCLIPPipeline(
-            prior=prior,
-            decoder=decoder,
-            text_proj=text_proj,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            super_res_first=super_res_first,
-            super_res_last=super_res_last,
-            prior_scheduler=prior_scheduler,
-            decoder_scheduler=decoder_scheduler,
-            super_res_scheduler=super_res_scheduler,
-        )
+        components = {
+            "prior": prior,
+            "decoder": decoder,
+            "text_proj": text_proj,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "super_res_first": super_res_first,
+            "super_res_last": super_res_last,
+            "prior_scheduler": prior_scheduler,
+            "decoder_scheduler": decoder_scheduler,
+            "super_res_scheduler": super_res_scheduler,
+        }
+
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "horse",
+            "generator": generator,
+            "prior_num_inference_steps": 2,
+            "decoder_num_inference_steps": 2,
+            "super_res_num_inference_steps": 2,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_unclip(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
         pipe = pipe.to(device)

         pipe.set_progress_bar_config(disable=None)

-        prompt = "horse"
-
-        generator = torch.Generator(device=device).manual_seed(0)
-        output = pipe(
-            [prompt],
-            generator=generator,
-            prior_num_inference_steps=2,
-            decoder_num_inference_steps=2,
-            super_res_num_inference_steps=2,
-            output_type="np",
-        )
+        output = pipe(**self.get_dummy_inputs(device))
         image = output.images

-        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = pipe(
-            [prompt],
-            generator=generator,
-            prior_num_inference_steps=2,
-            decoder_num_inference_steps=2,
-            super_res_num_inference_steps=2,
-            output_type="np",
+            **self.get_dummy_inputs(device),
             return_dict=False,
         )[0]

         image_slice = image[0, -3:, -3:, -1]
         image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

-        assert image.shape == (1, 128, 128, 3)
+        assert image.shape == (1, 64, 64, 3)

         expected_slice = np.array(
             [
-                0.0011,
-                0.0002,
-                0.9962,
-                0.9940,
-                0.0002,
-                0.9997,
-                0.0003,
-                0.9987,
-                0.9989,
+                0.9997,
+                0.9988,
+                0.0028,
+                0.9997,
+                0.9984,
+                0.9965,
+                0.0029,
+                0.9986,
+                0.0025,
             ]
         )
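The `get_dummy_inputs` helper above branches on the device string because `torch.Generator(device="mps")` was not supported at the time of this PR (an assumption worth re-checking on newer torch); on MPS, seeding falls back to the global RNG. The pattern in isolation (`make_generator` is a hypothetical helper name):

```python
import torch

def make_generator(device, seed=0):
    # torch.Generator lacked MPS support at the time, so seed the
    # global RNG on mps and a device-local generator elsewhere
    if str(device).startswith("mps"):
        return torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)
```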
@@ -254,47 +268,17 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
         class DummyScheduler:
             init_noise_sigma = 1

-        prior = self.dummy_prior
-        decoder = self.dummy_decoder
-        text_proj = self.dummy_text_proj
-        text_encoder = self.dummy_text_encoder
-        tokenizer = self.dummy_tokenizer
-        super_res_first = self.dummy_super_res_first
-        super_res_last = self.dummy_super_res_last
-
-        prior_scheduler = UnCLIPScheduler(
-            variance_type="fixed_small_log",
-            prediction_type="sample",
-            num_train_timesteps=1000,
-            clip_sample_range=5.0,
-        )
-
-        decoder_scheduler = UnCLIPScheduler(
-            variance_type="learned_range",
-            prediction_type="epsilon",
-            num_train_timesteps=1000,
-        )
-
-        super_res_scheduler = UnCLIPScheduler(
-            variance_type="fixed_small_log",
-            prediction_type="epsilon",
-            num_train_timesteps=1000,
-        )
-
-        pipe = UnCLIPPipeline(
-            prior=prior,
-            decoder=decoder,
-            text_proj=text_proj,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            super_res_first=super_res_first,
-            super_res_last=super_res_last,
-            prior_scheduler=prior_scheduler,
-            decoder_scheduler=decoder_scheduler,
-            super_res_scheduler=super_res_scheduler,
-        )
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
         pipe = pipe.to(device)

+        prior = components["prior"]
+        decoder = components["decoder"]
+        super_res_first = components["super_res_first"]
+        tokenizer = components["tokenizer"]
+        text_encoder = components["text_encoder"]
+
         generator = torch.Generator(device=device).manual_seed(0)
         dtype = prior.dtype
         batch_size = 1
@@ -362,6 +346,45 @@ class UnCLIPPipelineFastTests(unittest.TestCase):
         # make sure passing text embeddings manually is identical
         assert np.abs(image - image_from_text).max() < 1e-4

+    # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
+    # because UnCLIP GPU undeterminism requires a looser check.
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_attention_slicing_forward_pass(self):
+        test_max_difference = torch_device == "cpu"
+
+        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
+
+    # Overriding PipelineTesterMixin::test_inference_batch_single_identical
+    # because UnCLIP undeterminism requires a looser check.
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_inference_batch_single_identical(self):
+        test_max_difference = torch_device == "cpu"
+        relax_max_difference = True
+
+        self._test_inference_batch_single_identical(
+            test_max_difference=test_max_difference, relax_max_difference=relax_max_difference
+        )
+
+    def test_inference_batch_consistent(self):
+        if torch_device == "mps":
+            # TODO: MPS errors with larger batch sizes
+            batch_sizes = [2, 3]
+            self._test_inference_batch_consistent(batch_sizes=batch_sizes)
+        else:
+            self._test_inference_batch_consistent()
+
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_dict_tuple_outputs_equivalent(self):
+        return super().test_dict_tuple_outputs_equivalent()
+
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_save_load_local(self):
+        return super().test_save_load_local()
+
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_save_load_optional_components(self):
+        return super().test_save_load_optional_components()


 @nightly
 class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
@@ -420,16 +443,12 @@ class UnCLIPPipelineIntegrationTests(unittest.TestCase):
             output_type="np",
         )

-        image = np.asarray(pipeline.numpy_to_pil(output.images)[0], dtype=np.float32)
-        expected_image = np.asarray(pipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
-
-        # Karlo is extremely likely to strongly deviate depending on which hardware is used
-        # Here we just check that the image doesn't deviate more than 10 pixels from the reference image on average
-        avg_diff = np.abs(image - expected_image).mean()
-        assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
+        image = output.images[0]

         assert image.shape == (256, 256, 3)
+        assert_mean_pixel_difference(image, expected_image)

     def test_unclip_pipeline_with_sequential_cpu_offloading(self):
         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
...
@@ -39,16 +39,22 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )

+from ...test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
+
 torch.backends.cuda.matmul.allow_tf32 = False


-class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
-    def tearDown(self):
-        # clean up the VRAM after each test
-        super().tearDown()
-        gc.collect()
-        torch.cuda.empty_cache()
+class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = UnCLIPImageVariationPipeline
+
+    required_optional_params = [
+        "generator",
+        "return_dict",
+        "decoder_num_inference_steps",
+        "super_res_num_inference_steps",
+    ]
+    num_inference_steps_args = [
+        "decoder_num_inference_steps",
+        "super_res_num_inference_steps",
+    ]

     @property
     def text_embedder_hidden_size(self):
@@ -124,7 +130,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         torch.manual_seed(0)

         model_kwargs = {
-            "sample_size": 64,
+            "sample_size": 32,
             # RGB in channels
             "in_channels": 3,
             # Out channels is double in channels because predicts mean and variance
@@ -146,7 +152,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
     @property
     def dummy_super_res_kwargs(self):
         return {
-            "sample_size": 128,
+            "sample_size": 64,
             "layers_per_block": 1,
             "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
             "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
@@ -170,7 +176,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         model = UNet2DModel(**self.dummy_super_res_kwargs)
         return model

-    def get_pipeline(self, device):
+    def get_dummy_components(self):
         decoder = self.dummy_decoder
         text_proj = self.dummy_text_proj
         text_encoder = self.dummy_text_encoder
@@ -194,26 +200,24 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         image_encoder = self.dummy_image_encoder

-        pipe = UnCLIPImageVariationPipeline(
-            decoder=decoder,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            text_proj=text_proj,
-            feature_extractor=feature_extractor,
-            image_encoder=image_encoder,
-            super_res_first=super_res_first,
-            super_res_last=super_res_last,
-            decoder_scheduler=decoder_scheduler,
-            super_res_scheduler=super_res_scheduler,
-        )
-
-        pipe = pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        return pipe
+        return {
+            "decoder": decoder,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "text_proj": text_proj,
+            "feature_extractor": feature_extractor,
+            "image_encoder": image_encoder,
+            "super_res_first": super_res_first,
+            "super_res_last": super_res_last,
+            "decoder_scheduler": decoder_scheduler,
+            "super_res_scheduler": super_res_scheduler,
+        }

-    def get_pipeline_inputs(self, device, seed, pil_image=False):
+    def get_dummy_inputs(self, device, seed=0, pil_image=True):
         input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
-        generator = torch.Generator(device=device).manual_seed(seed)
+
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)

         if pil_image:
@@ -232,16 +236,20 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
     def test_unclip_image_variation_input_tensor(self):
         device = "cpu"
-        seed = 0

-        pipe = self.get_pipeline(device)
-        pipeline_inputs = self.get_pipeline_inputs(device, seed)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+
+        pipe.set_progress_bar_config(disable=None)
+
+        pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)

         output = pipe(**pipeline_inputs)
         image = output.images

-        tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed)
+        tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)

         image_from_tuple = pipe(
             **tuple_pipeline_inputs,
@@ -251,19 +259,19 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]
         image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

-        assert image.shape == (1, 128, 128, 3)
+        assert image.shape == (1, 64, 64, 3)

         expected_slice = np.array(
             [
-                0.9988,
-                0.9997,
-                0.9944,
-                0.0003,
-                0.0003,
-                0.9974,
-                0.0003,
-                0.0004,
-                0.9931,
+                0.9997,
+                0.0002,
+                0.9997,
+                0.9997,
+                0.9969,
+                0.0023,
+                0.9997,
+                0.9969,
+                0.9970,
             ]
         )
@@ -272,16 +280,20 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
     def test_unclip_image_variation_input_image(self):
         device = "cpu"
-        seed = 0

-        pipe = self.get_pipeline(device)
-        pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+
+        pipe.set_progress_bar_config(disable=None)
+
+        pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)

         output = pipe(**pipeline_inputs)
         image = output.images

-        tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
+        tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)

         image_from_tuple = pipe(
             **tuple_pipeline_inputs,
@@ -291,32 +303,24 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]
         image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

-        assert image.shape == (1, 128, 128, 3)
+        assert image.shape == (1, 64, 64, 3)

-        expected_slice = np.array(
-            [
-                0.9988,
-                0.9997,
-                0.9944,
-                0.0003,
-                0.0003,
-                0.9974,
-                0.0003,
-                0.0004,
-                0.9931,
-            ]
-        )
+        expected_slice = np.array([0.9997, 0.0003, 0.9997, 0.9997, 0.9970, 0.0024, 0.9997, 0.9971, 0.9971])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

     def test_unclip_image_variation_input_list_images(self):
         device = "cpu"
-        seed = 0

-        pipe = self.get_pipeline(device)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+
+        pipe.set_progress_bar_config(disable=None)

-        pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
+        pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
         pipeline_inputs["image"] = [
             pipeline_inputs["image"],
             pipeline_inputs["image"],
@@ -325,7 +329,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         output = pipe(**pipeline_inputs)
         image = output.images

-        tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
+        tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
         tuple_pipeline_inputs["image"] = [
             tuple_pipeline_inputs["image"],
             tuple_pipeline_inputs["image"],
@@ -339,19 +343,19 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]
         image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

-        assert image.shape == (2, 128, 128, 3)
+        assert image.shape == (2, 64, 64, 3)

         expected_slice = np.array(
             [
-                0.9997,
-                0.9997,
-                0.0003,
-                0.0003,
-                0.9950,
-                0.0003,
-                0.9993,
-                0.9957,
-                0.0004,
+                0.9997,
+                0.9989,
+                0.0008,
+                0.0021,
+                0.9960,
+                0.0018,
+                0.0014,
+                0.0002,
+                0.9933,
             ]
         )
@@ -360,11 +364,15 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
     def test_unclip_image_variation_input_num_images_per_prompt(self):
         device = "cpu"
-        seed = 0

-        pipe = self.get_pipeline(device)
-        pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+
+        pipe.set_progress_bar_config(disable=None)
+
+        pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
         pipeline_inputs["image"] = [
             pipeline_inputs["image"],
             pipeline_inputs["image"],
@@ -373,7 +381,7 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         output = pipe(**pipeline_inputs, num_images_per_prompt=2)
         image = output.images

-        tuple_pipeline_inputs = self.get_pipeline_inputs(device, seed, pil_image=True)
+        tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
         tuple_pipeline_inputs["image"] = [
             tuple_pipeline_inputs["image"],
             tuple_pipeline_inputs["image"],
@@ -388,18 +396,18 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]
         image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

-        assert image.shape == (4, 128, 128, 3)
+        assert image.shape == (4, 64, 64, 3)

         expected_slice = np.array(
             [
-                0.9997,
-                0.9997,
-                0.0008,
-                0.9952,
-                0.9980,
-                0.9997,
-                0.9961,
-                0.9997,
-                0.9995,
+                0.9980,
+                0.9997,
+                0.0023,
+                0.0029,
+                0.9997,
+                0.9985,
+                0.9997,
+                0.0010,
+                0.9995,
             ]
         )
@@ -409,12 +417,16 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
     def test_unclip_passed_image_embed(self):
         device = torch.device("cpu")
-        seed = 0

         class DummyScheduler:
             init_noise_sigma = 1

-        pipe = self.get_pipeline(device)
+        components = self.get_dummy_components()
+
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+
+        pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device=device).manual_seed(0)
         dtype = pipe.decoder.dtype
@@ -435,13 +447,13 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
             shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
         )

-        pipeline_inputs = self.get_pipeline_inputs(device, seed)
+        pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)

         img_out_1 = pipe(
             **pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
         ).images

-        pipeline_inputs = self.get_pipeline_inputs(device, seed)
+        pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
         # Don't pass image, instead pass embedding
         image = pipeline_inputs.pop("image")
         image_embeddings = pipe.image_encoder(image).image_embeds
@@ -456,6 +468,45 @@ class UnCLIPImageVariationPipelineFastTests(unittest.TestCase):
         # make sure passing text embeddings manually is identical
         assert np.abs(img_out_1 - img_out_2).max() < 1e-4

+    # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
+    # because UnCLIP GPU undeterminism requires a looser check.
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_attention_slicing_forward_pass(self):
+        test_max_difference = torch_device == "cpu"
+
+        self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference)
+
+    # Overriding PipelineTesterMixin::test_inference_batch_single_identical
+    # because UnCLIP undeterminism requires a looser check.
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_inference_batch_single_identical(self):
+        test_max_difference = torch_device == "cpu"
+        relax_max_difference = True
+
+        self._test_inference_batch_single_identical(
+            test_max_difference=test_max_difference, relax_max_difference=relax_max_difference
+        )
+
+    def test_inference_batch_consistent(self):
+        if torch_device == "mps":
+            # TODO: MPS errors with larger batch sizes
+            batch_sizes = [2, 3]
+            self._test_inference_batch_consistent(batch_sizes=batch_sizes)
+        else:
+            self._test_inference_batch_consistent()
+
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_dict_tuple_outputs_equivalent(self):
+        return super().test_dict_tuple_outputs_equivalent()
+
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_save_load_local(self):
+        return super().test_save_load_local()
+
+    @unittest.skipIf(torch_device == "mps", reason="MPS inconsistent")
+    def test_save_load_optional_components(self):
+        return super().test_save_load_optional_components()


 @slow
 @require_torch_gpu
@@ -488,12 +539,8 @@ class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
             output_type="np",
         )

-        image = np.asarray(pipeline.numpy_to_pil(output.images)[0], dtype=np.float32)
-        expected_image = np.asarray(pipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
-
-        # Karlo is extremely likely to strongly deviate depending on which hardware is used
-        # Here we just check that the image doesn't deviate more than 10 pixels from the reference image on average
-        avg_diff = np.abs(image - expected_image).mean()
-        assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
+        image = output.images[0]

         assert image.shape == (256, 256, 3)
+        assert_mean_pixel_difference(image, expected_image)
...
@@ -28,9 +28,6 @@ from diffusers.utils.testing_utils import require_torch, torch_device
 torch.backends.cuda.matmul.allow_tf32 = False


-ALLOWED_REQUIRED_ARGS = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
-
-
 @require_torch
 class PipelineTesterMixin:
     """
@@ -39,6 +36,10 @@ class PipelineTesterMixin:
     equivalence of dict and tuple outputs, etc.
     """

+    allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
+    required_optional_params = ["generator", "num_inference_steps", "return_dict"]
+    num_inference_steps_args = ["num_inference_steps"]
+
     # set these parameters to False in the child class if the pipeline does not support the corresponding functionality
     test_attention_slicing = True
     test_cpu_offload = True
@@ -120,15 +121,17 @@ class PipelineTesterMixin:
             if param == "kwargs":
                 # kwargs can be added if arguments of pipeline call function are deprecated
                 continue
-            assert param in ALLOWED_REQUIRED_ARGS
+            assert param in self.allowed_required_args

         optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})

-        required_optional_params = ["generator", "num_inference_steps", "return_dict"]
-        for param in required_optional_params:
+        for param in self.required_optional_params:
             assert param in optional_parameters

     def test_inference_batch_consistent(self):
+        self._test_inference_batch_consistent()
+
+    def _test_inference_batch_consistent(self, batch_sizes=[2, 4, 13]):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.to(torch_device)
@@ -140,10 +143,10 @@ class PipelineTesterMixin:
         logger.setLevel(level=diffusers.logging.FATAL)

         # batchify inputs
-        for batch_size in [2, 4, 13]:
+        for batch_size in batch_sizes:
             batched_inputs = {}
             for name, value in inputs.items():
-                if name in ALLOWED_REQUIRED_ARGS:
+                if name in self.allowed_required_args:
                     # prompt is string
                     if name == "prompt":
                         len_prompt = len(value)
@@ -160,7 +163,9 @@ class PipelineTesterMixin:
                 else:
                     batched_inputs[name] = value

-            batched_inputs["num_inference_steps"] = inputs["num_inference_steps"]
+            for arg in self.num_inference_steps_args:
+                batched_inputs[arg] = inputs[arg]
+
             batched_inputs["output_type"] = None

             if self.pipeline_class.__name__ == "DanceDiffusionPipeline":
@@ -182,12 +187,26 @@ class PipelineTesterMixin:
         logger.setLevel(level=diffusers.logging.WARNING)

     def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical()
+
+    def _test_inference_batch_single_identical(
+        self, test_max_difference=None, test_mean_pixel_difference=None, relax_max_difference=False
+    ):
         if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]:
             # RePaint can hardly be made deterministic since the scheduler is currently always
             # indeterministic
             # CycleDiffusion is also slightly undeterministic
             return

+        if test_max_difference is None:
+            # TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
+            # make sure that batched and non-batched is identical
+            test_max_difference = torch_device != "mps"
+
+        if test_mean_pixel_difference is None:
+            # TODO same as above
+            test_mean_pixel_difference = torch_device != "mps"
+
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.to(torch_device)
@@ -202,7 +221,7 @@ class PipelineTesterMixin:
         batched_inputs = {}
         batch_size = 3
         for name, value in inputs.items():
-            if name in ALLOWED_REQUIRED_ARGS:
+            if name in self.allowed_required_args:
                 # prompt is string
                 if name == "prompt":
                     len_prompt = len(value)
@@ -221,7 +240,8 @@ class PipelineTesterMixin:
             else:
                 batched_inputs[name] = value

-        batched_inputs["num_inference_steps"] = inputs["num_inference_steps"]
+        for arg in self.num_inference_steps_args:
+            batched_inputs[arg] = inputs[arg]

         if self.pipeline_class.__name__ != "DanceDiffusionPipeline":
             batched_inputs["output_type"] = "np"
@@ -234,10 +254,19 @@ class PipelineTesterMixin:
         output = pipe(**inputs)
         logger.setLevel(level=diffusers.logging.WARNING)

-        if torch_device != "mps":
-            # TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
-            # make sure that batched and non-batched is identical
-            assert np.abs(output_batch[0][0] - output[0][0]).max() < 1e-4
+        if test_max_difference:
+            if relax_max_difference:
+                # Taking the median of the largest <n> differences
+                # is resilient to outliers
+                diff = np.abs(output_batch[0][0] - output[0][0])
+                diff.sort()
+                max_diff = np.median(diff[-5:])
+            else:
+                max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
+            assert max_diff < 1e-4
+
+        if test_mean_pixel_difference:
+            assert_mean_pixel_difference(output_batch[0][0], output[0][0])

     def test_dict_tuple_outputs_equivalent(self):
         if torch_device == "mps" and self.pipeline_class in (
@@ -278,7 +307,9 @@ class PipelineTesterMixin:
         times = []
         for num_steps in [9, 6, 3]:
             inputs = self.get_dummy_inputs(torch_device)
-            inputs["num_inference_steps"] = num_steps
+
+            for arg in self.num_inference_steps_args:
+                inputs[arg] = num_steps

             start_time = time.time()
             output = pipe(**inputs)[0]
@@ -419,6 +450,9 @@ class PipelineTesterMixin:
         self.assertTrue(np.isnan(output_cuda).sum() == 0)

     def test_attention_slicing_forward_pass(self):
+        self._test_attention_slicing_forward_pass()
+
+    def _test_attention_slicing_forward_pass(self, test_max_difference=True):
         if not self.test_attention_slicing:
             return
@@ -448,9 +482,12 @@ class PipelineTesterMixin:
         inputs = self.get_dummy_inputs(torch_device)
         output_with_slicing = pipe(**inputs)[0]

-        max_diff = np.abs(output_with_slicing - output_without_slicing).max()
-        self.assertLess(max_diff, 1e-3, "Attention slicing should not affect the inference results")
+        if test_max_difference:
+            max_diff = np.abs(output_with_slicing - output_without_slicing).max()
+            self.assertLess(max_diff, 1e-3, "Attention slicing should not affect the inference results")
+
+        assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])

     @unittest.skipIf(
         torch_device != "cuda" or not is_accelerate_available(),
         reason="CPU offload is only available with CUDA and `accelerate` installed",
@@ -518,3 +555,13 @@ class PipelineTesterMixin:
         with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
             _ = pipe(**inputs)
         self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+
+
+# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
+# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
+# reference image.
+def assert_mean_pixel_difference(image, expected_image):
+    image = np.asarray(DiffusionPipeline.numpy_to_pil(image)[0], dtype=np.float32)
+    expected_image = np.asarray(DiffusionPipeline.numpy_to_pil(expected_image)[0], dtype=np.float32)
+    avg_diff = np.abs(image - expected_image).mean()
+    assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average"
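The `relax_max_difference` branch added to `_test_inference_batch_single_identical` above compares the median of the five largest element-wise differences against the threshold, so a few outlier pixels (common with unCLIP on GPU) cannot fail the test on their own. A toy illustration with invented values:

```python
import numpy as np

# 1000 near-identical elements plus one outlier pixel
diff = np.array([1e-6] * 995 + [2e-6, 3e-6, 5e-6, 8e-6, 0.2])
diff.sort()

assert diff.max() > 1e-4            # a strict max check would fail here...
assert np.median(diff[-5:]) < 1e-4  # ...but the relaxed median check passes
```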