Unverified Commit 584e1bd6 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

Merge pull request #78 from NVIDIA/tkurth/attention-perf-test-fix

fixing attention perf test attempt 1
parents 26ce5cb5 47beb41a
...@@ -214,7 +214,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -214,7 +214,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch") self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
@unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
@parameterized.expand( @parameterized.expand(
[ [
# self attention # self attention
...@@ -223,6 +222,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -223,6 +222,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
], ],
skip_on_empty=True, skip_on_empty=True,
) )
@unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True): def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
# extract some parameters # extract some parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment