Commit 47beb41a authored by Thorsten Kurth

fixing attention perf test attempt 1

parent 26ce5cb5
@@ -214,7 +214,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
             self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
-    @unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
     @parameterized.expand(
         [
             # self attention
@@ -223,6 +222,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         ],
         skip_on_empty=True,
     )
+    @unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
     def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
         # extract some parameters
...
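
The change moves the `@unittest.skipUnless` guard from above `@parameterized.expand` to directly above `test_perf`. A plausible reading of the fix: a decorator placed above `expand` wraps only the placeholder that `expand` leaves behind, not the test methods `expand` generates, so the performance test was not actually being skipped on machines without CUDA. The sketch below illustrates the pattern only; the parameter values and the `HAS_CUDA` flag are made up and it is not code from this repository.

```python
# Minimal sketch of the decorator ordering, assuming the `parameterized`
# package and torch are installed. Values and names are illustrative only.
import unittest

import torch
from parameterized import parameterized

HAS_CUDA = torch.cuda.is_available()


class TestDecoratorOrder(unittest.TestCase):

    # expand() generates one test method per parameter tuple at class-creation
    # time. A skip decorator placed *above* expand would wrap only the leftover
    # placeholder, so the generated tests would still run without CUDA.
    @parameterized.expand(
        [
            (4, 8),
            (8, 16),
        ],
        skip_on_empty=True,
    )
    @unittest.skipUnless(HAS_CUDA, "skipping performance test because CUDA is not available")
    def test_perf_sketch(self, batch_size, channels):
        # Each generated test calls the skip-wrapped function, so SkipTest is
        # raised here (and the test reported as skipped) when CUDA is missing.
        x = torch.randn(batch_size, channels, device="cuda")
        self.assertEqual(x.shape, (batch_size, channels))


if __name__ == "__main__":
    unittest.main()
```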