Commit c90b421a authored by Mauro Bisson

Added Thorsten's fix for a bug regarding contiguous storage, taken from:

https://github.com/NVIDIA/torch-harmonics/compare/main...azrael417:torch-harmonics:tkurth/mauro-rebase
parent 7aa95ce5
@@ -81,7 +81,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         ],
         skip_on_empty=True,
     )
-    def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
+    def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
         """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""

         nlat_in, nlon_in = in_shape
@@ -161,7 +161,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         ],
         skip_on_empty=True,
     )
-    def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
+    def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
         """Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set ot the whole sphere"""

         nlat_in, nlon_in = in_shape
@@ -223,7 +223,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         skip_on_empty=True,
     )
     @unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
-    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
+    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):

         # extract some parameters
         nlat_in, nlon_in = in_shape
...
@@ -479,9 +479,10 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         qw = qw.reshape(B*nh, -1, H, W)

         # convert to float32
-        kw = kw.to(torch.float32)
-        vw = vw.to(torch.float32)
-        qw = qw.to(torch.float32)
+        inp_dtype = kw.dtype
+        kw = kw.to(torch.float32).contiguous()
+        vw = vw.to(torch.float32).contiguous()
+        qw = qw.to(torch.float32).contiguous()

         output = attention_cuda_extension.forward(kw, vw, qw, quad_weights,
                                                   col_idx, row_off,
@@ -490,6 +491,9 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         _, C, H, W = output.shape
         output = output.reshape(B, -1, H, W)

+        # convert back precision
+        output = output.to(dtype=inp_dtype)
+
         return output

     @staticmethod
...
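The forward-path change above follows a common pattern when wrapping a custom CUDA kernel: record the incoming dtype, hand the kernel contiguous float32 tensors, and cast the result back to the original precision at the end. Below is a minimal, self-contained sketch of that pattern; `run_attention_kernel` and its callable argument are hypothetical stand-ins for the real `attention_cuda_extension.forward` call, not part of torch-harmonics.

    import torch

    def run_attention_kernel(kw, vw, qw, kernel_fn):
        # remember the precision the caller passed in (e.g. float16/bfloat16 under AMP)
        inp_dtype = kw.dtype

        # custom CUDA kernels typically expect float32, densely packed storage;
        # .contiguous() only copies when the tensor is not already contiguous
        kw = kw.to(torch.float32).contiguous()
        vw = vw.to(torch.float32).contiguous()
        qw = qw.to(torch.float32).contiguous()

        out = kernel_fn(kw, vw, qw)

        # convert back to the caller's precision
        return out.to(dtype=inp_dtype)

    # usage with a stand-in for the extension call
    k = v = q = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16)
    out = run_attention_kernel(k, v, q, lambda k_, v_, q_: k_ + v_ + q_)
    assert out.dtype == torch.bfloat16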
@@ -291,7 +291,7 @@ class NeighborhoodAttentionS2(nn.Module):
         # set the last value
         row_offset[row + 1] = idz + 1

-        row_offset = torch.from_numpy(row_offset)
+        row_offset = torch.from_numpy(row_offset).contiguous()
         self.max_psi_nnz = col_idx.max().item() + 1

         self.register_buffer("psi_row_idx", row_idx, persistent=False)
...
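The same storage concern motivates the `.contiguous()` on the precomputed row offsets: index buffers registered on the module are later handed to the CUDA extension and should be densely packed. A minimal sketch of that idea, using hypothetical names (`SparseNeighborhood`, `row_off`) rather than the actual torch-harmonics module:

    import numpy as np
    import torch
    import torch.nn as nn

    class SparseNeighborhood(nn.Module):
        """Holds CSR-style row offsets for a custom kernel (illustrative only)."""

        def __init__(self, row_offset_np: np.ndarray):
            super().__init__()
            # torch.from_numpy shares memory with the NumPy array; .contiguous()
            # ensures dense row-major storage even if the array was a strided view
            row_offset = torch.from_numpy(row_offset_np).contiguous()
            # non-persistent buffer: moves with .to(device) but is not saved in state_dict
            self.register_buffer("row_off", row_offset, persistent=False)

    # usage
    mod = SparseNeighborhood(np.arange(10, dtype=np.int64))
    print(mod.row_off.is_contiguous())  # True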