Unverified Commit cd8b7507 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

speed up attend-and-excite fast tests (#3079)

parent 3b641eab
@@ -44,7 +44,7 @@ class StableDiffusionAttendAndExcitePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
-            layers_per_block=2,
+            layers_per_block=1,
             sample_size=32,
             in_channels=4,
             out_channels=4,
@@ -111,7 +111,7 @@ class StableDiffusionAttendAndExcitePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "prompt": "a cat and a frog",
             "token_indices": [2, 5],
             "generator": generator,
-            "num_inference_steps": 2,
+            "num_inference_steps": 1,
             "guidance_scale": 6.0,
             "output_type": "numpy",
             "max_iter_to_alter": 2,
@@ -132,13 +132,18 @@ class StableDiffusionAttendAndExcitePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         image_slice = image[0, -3:, -3:, -1]

         self.assertEqual(image.shape, (1, 64, 64, 3))
-        expected_slice = np.array([0.5743, 0.6081, 0.4975, 0.5021, 0.5441, 0.4699, 0.4988, 0.4841, 0.4851])
+        expected_slice = np.array(
+            [0.63905364, 0.62897307, 0.48599017, 0.5133624, 0.5550048, 0.45769516, 0.50326973, 0.5023139, 0.45384496]
+        )
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)

     def test_inference_batch_consistent(self):
         # NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
-        self._test_inference_batch_consistent(batch_sizes=[2, 4])
+        self._test_inference_batch_consistent(batch_sizes=[1, 2])
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(batch_size=2)

     @require_torch_gpu
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment