Commit 79fa6ad9 authored by Thorsten Kurth's avatar Thorsten Kurth
Browse files

adjusted perf test shapes

parent ec413b4d
...@@ -58,6 +58,7 @@ except ImportError as err: ...@@ -58,6 +58,7 @@ except ImportError as err:
attention_cuda_extension = None attention_cuda_extension = None
_cuda_extension_available = False _cuda_extension_available = False
_perf_test_thresholds = {"fwd_ms": 50, "bwd_ms": 150}
class TestNeighborhoodAttentionS2(unittest.TestCase): class TestNeighborhoodAttentionS2(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -213,15 +214,16 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -213,15 +214,16 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch") self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
@unittest.skipIf((not torch.cuda.is_available()) or (not _cuda_extension_available), "skipping performance test because CUDA is not available") @unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
@parameterized.expand( @parameterized.expand(
[ [
# self attention # self attention
[1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5], #[1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
[1, 256, 1, (361, 720), (361, 720), "equiangular", "equiangular", 1e-5, 1e-5],
], ],
skip_on_empty=True, skip_on_empty=True,
) )
def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False): def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
# extract some parameters # extract some parameters
nlat_in, nlon_in = in_shape nlat_in, nlon_in = in_shape
...@@ -267,7 +269,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -267,7 +269,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
elapsed_time = time_forward_start.elapsed_time(time_forward_end) elapsed_time = time_forward_start.elapsed_time(time_forward_end)
if verbose: if verbose:
print(f"Forward execution time: {elapsed_time} ms") print(f"Forward execution time: {elapsed_time} ms")
self.assertTrue(elapsed_time < 150) self.assertTrue(elapsed_time < _perf_test_thresholds["fwd_ms"])
# sync weights: # sync weights:
with torch.no_grad(): with torch.no_grad():
...@@ -286,7 +288,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -286,7 +288,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
v_gpu.requires_grad = True v_gpu.requires_grad = True
out_gpu = att_gpu(q_gpu, k_gpu, v_gpu) out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0") out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device=self.device)
time_backward_start = torch.cuda.Event(enable_timing=True) time_backward_start = torch.cuda.Event(enable_timing=True)
time_backward_end = torch.cuda.Event(enable_timing=True) time_backward_end = torch.cuda.Event(enable_timing=True)
...@@ -294,7 +296,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -294,7 +296,6 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
# warmup # warmup
out_gpu.backward(out_grad, retain_graph=True) out_gpu.backward(out_grad, retain_graph=True)
# print("out_grad_stride=",out_grad.stride())
time_backward_start.record() time_backward_start.record()
out_gpu.backward(out_grad) out_gpu.backward(out_grad)
time_backward_end.record() time_backward_end.record()
...@@ -302,7 +303,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase): ...@@ -302,7 +303,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
elapsed_time = time_backward_start.elapsed_time(time_backward_end) elapsed_time = time_backward_start.elapsed_time(time_backward_end)
if verbose: if verbose:
print(f"Backward execution time: {elapsed_time} ms") print(f"Backward execution time: {elapsed_time} ms")
self.assertTrue(elapsed_time < 400) self.assertTrue(elapsed_time < _perf_test_thresholds["bwd_ms"])
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment