Unverified Commit 69996938 authored by YiYi Xu, committed by GitHub

speed up Shap-E fast test (#5686)



Skip the rendering step in the Shap-E fast tests by returning latents instead of rendered frames.
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent 9ae90593
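The speedup comes from requesting latents instead of rendered frames. A minimal sketch of the two output paths, assuming the public "openai/shap-e" checkpoint; the prompt, seed, and step counts are illustrative, not taken from this test:

import torch
from diffusers import ShapEPipeline

pipe = ShapEPipeline.from_pretrained("openai/shap-e")
generator = torch.Generator().manual_seed(0)

# Rendering path (what the fast test used before): the pipeline decodes each
# latent into a stack of RGB frames, so output.images[0] is a (frames, H, W, 3) array.
# output = pipe("a shark", generator=generator, num_inference_steps=64,
#               frame_size=32, output_type="np")

# Latent path (what the fast test uses now): the rendering step is skipped
# entirely and output.images[0] is the raw latent tensor.
output = pipe(
    "a shark",
    generator=generator,
    num_inference_steps=1,
    frame_size=32,
    output_type="latent",
)
latents = output.images[0]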
@@ -160,7 +160,7 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "generator": generator,
             "num_inference_steps": 1,
             "frame_size": 32,
-            "output_type": "np",
+            "output_type": "latent",
         }
         return inputs
@@ -176,24 +176,12 @@ class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         output = pipe(**self.get_dummy_inputs(device))
         image = output.images[0]
-        image_slice = image[0, -3:, -3:, -1]
+        image = image.cpu().numpy()
+        image_slice = image[-3:, -3:]
 
-        assert image.shape == (20, 32, 32, 3)
+        assert image.shape == (32, 16)
 
-        expected_slice = np.array(
-            [
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-            ]
-        )
+        expected_slice = np.array([-1.0000, -0.6241, 1.0000, -0.8978, -0.6866, 0.7876, -0.7473, -0.2874, 0.6103])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_inference_batch_consistent(self):
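For reference, the new assertions line up: with a (32, 16) latent per sample, slicing the last three rows and columns gives a 3x3 patch, i.e. the nine values compared against expected_slice. A quick standalone check:

import numpy as np

# A (32, 16) latent replaces the old (20, 32, 32, 3) stack of rendered frames.
latent = np.zeros((32, 16), dtype=np.float32)
patch = latent[-3:, -3:]              # last 3 rows, last 3 columns
assert patch.shape == (3, 3)
assert patch.flatten().shape == (9,)  # matches the 9 entries in expected_slice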
@@ -181,7 +181,7 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "generator": generator,
             "num_inference_steps": 1,
             "frame_size": 32,
-            "output_type": "np",
+            "output_type": "latent",
         }
         return inputs
@@ -197,22 +197,12 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         output = pipe(**self.get_dummy_inputs(device))
         image = output.images[0]
-        image_slice = image[0, -3:, -3:, -1]
+        image_slice = image[-3:, -3:].cpu().numpy()
 
-        assert image.shape == (20, 32, 32, 3)
+        assert image.shape == (32, 16)
 
         expected_slice = np.array(
-            [
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-                0.00039216,
-            ]
+            [-1.0, 0.40668195, 0.57322013, -0.9469888, 0.4283227, 0.30348337, -0.81094897, 0.74555075, 0.15342723]
         )
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -493,7 +493,7 @@ class PipelineTesterMixin:
         assert output_batch[0].shape[0] == batch_size
 
-        max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
+        max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
         assert max_diff < expected_max_diff
 
     def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
@@ -702,7 +702,7 @@ class PipelineTesterMixin:
         self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
 
         if test_mean_pixel_difference:
-            assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
+            assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
 
     @unittest.skipIf(
         torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),