Commit 1a47fa08 authored by Thorsten Kurth

streamlining perf test

parent a07c5b2b
@@ -65,10 +65,9 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
             self.device = torch.device("cuda:0")
             torch.cuda.set_device(self.device.index)
             torch.cuda.manual_seed(333)
-            torch.manual_seed(333)
         else:
             self.device = torch.device("cpu")
             torch.manual_seed(333)
 
     @parameterized.expand(
         [
@@ -78,7 +77,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
             [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
             [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-5, 1e-3],
             [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-5, 1e-3],
-        ]
+        ],
+        skip_on_empty=True,
     )
     def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
         """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""
@@ -157,7 +157,8 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
             # [4, 4, 4, (6, 12), (6, 12), "equiangular", "equiangular", 1e-5, 1e-3],
             [4, 4, 1, (6, 12), (6, 12), "legendre-gauss", "legendre-gauss", 1e-2, 0],
             [4, 4, 1, (6, 12), (6, 12), "lobatto", "lobatto", 1e-2, 0],
-        ]
+        ],
+        skip_on_empty=True,
     )
     def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
         """Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set ot the whole sphere"""
@@ -212,30 +213,26 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
             self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")
 
+    @unittest.skipIf((not torch.cuda.is_available()) or (not _cuda_extension_available), "skipping performance test because CUDA is not available")
     @parameterized.expand(
         [
             # self attention
             [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):
-        # this test only makes sense when CUDA version is available
-        if torch.cuda.is_available():
-            if not _cuda_extension_available:
-                print("WARNING: Problem loading CUDA attention module")
-                return
+    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
 
         # extract some parameters
         nlat_in, nlon_in = in_shape
         nlat_out, nlon_out = out_shape
 
         # TODO: this test seems hardcoded for GPU. Is this necessary?
-        k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
+        k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
         k_gpu.requires_grad = False
-        v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
+        v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device=self.device)
         v_gpu.requires_grad = False
-        q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cuda:0")
+        q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device=self.device)
         q_gpu.requires_grad = False
 
         # set up layers
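
Note: moving the CUDA check out of the test body and into @unittest.skipIf means an environment without a GPU or without the compiled extension now reports the test as skipped, whereas the old early return let it pass silently. A minimal sketch of the gating pattern, assuming a module-level flag like the _cuda_extension_available used above (the guarded import here is hypothetical):

    import unittest

    import torch

    # assumption: the flag records at import time whether the compiled extension loaded
    try:
        import my_attention_cuda_extension  # hypothetical compiled extension
        _cuda_extension_available = True
    except ImportError:
        _cuda_extension_available = False

    class PerfTest(unittest.TestCase):
        @unittest.skipIf((not torch.cuda.is_available()) or (not _cuda_extension_available), "skipping performance test because CUDA is not available")
        def test_perf(self):
            # the body may now assume both CUDA and the extension are present
            x = torch.randn(8, device="cuda:0")
            self.assertEqual(x.device.type, "cuda")
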
@@ -244,10 +241,9 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         time_layer_setup_start.record()
         att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
                                           in_shape=in_shape, out_shape=out_shape,
-                                          grid_in=grid_in, grid_out=grid_out, bias=True).to("cuda:0")
+                                          grid_in=grid_in, grid_out=grid_out, bias=True).to(self.device)
         time_layer_setup_end.record()
         torch.cuda.synchronize()
-        # print(f"Layer setup: {time_layer_setup_start.elapsed_time(time_layer_setup_end)} ms")
 
         # random weights
         with torch.no_grad():
@@ -268,9 +264,10 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
         time_forward_end.record()
         torch.cuda.synchronize()
 
         elapsed_time = time_forward_start.elapsed_time(time_forward_end)
-        assert elapsed_time < 150, "Forward pass took much too long, there must be a performance regression!"
+        if verbose:
+            print(f"Forward execution time: {elapsed_time} ms")
+        self.assertTrue(elapsed_time < 150)
 
         # sync weights:
         with torch.no_grad():
@@ -302,9 +299,10 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         out_gpu.backward(out_grad)
         time_backward_end.record()
         torch.cuda.synchronize()
-        # print(f"Backward execution: {time_backward_start.elapsed_time(time_backward_end)} ms")
 
         elapsed_time = time_backward_start.elapsed_time(time_backward_end)
-        assert elapsed_time < 400, "Backward pass took much too long, there must be a performance regression!"
+        if verbose:
+            print(f"Backward execution time: {elapsed_time} ms")
+        self.assertTrue(elapsed_time < 400)
 
 
 if __name__ == "__main__":
...
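
Note: the timings in this test use CUDA events, which measure elapsed time on the GPU itself; because kernel launches are asynchronous, torch.cuda.synchronize() must complete before elapsed_time is read. A minimal sketch of the measurement pattern (the workload and threshold are illustrative):

    import torch

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    x = torch.randn(1024, 1024, device="cuda:0")

    start.record()            # enqueue the start marker on the current stream
    y = x @ x                 # the work being timed
    end.record()              # enqueue the end marker
    torch.cuda.synchronize()  # wait until both events have actually been reached

    elapsed_ms = start.elapsed_time(end)  # milliseconds between the two events
    assert elapsed_ms < 1000.0, "unexpectedly slow matmul"
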