Unverified Commit 2b4f849d authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[PipelineTesterMixin] Handle non-image outputs for attn slicing test (#2504)



* [PipelineTesterMixin] Handle non-image outputs for batch/sinle inference test

* style

---------
Co-authored-by: default avatarWilliam Berman <WLBberman@gmail.com>
parent e4c356d3
......@@ -450,7 +450,9 @@ class PipelineTesterMixin:
def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass()
def _test_attention_slicing_forward_pass(self, test_max_difference=True, expected_max_diff=1e-3):
def _test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
......@@ -474,6 +476,7 @@ class PipelineTesterMixin:
max_diff = np.abs(output_with_slicing - output_without_slicing).max()
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
@unittest.skipIf(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment