Unverified Commit 49a61eee authored by Thorsten Kurth, committed by GitHub

Merge pull request #83 from NVIDIA/maurob/devel

Optimized forward kernel for attention
parents c485a1fb c90b421a
@@ -81,7 +81,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         ],
         skip_on_empty=True,
     )
-    def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
+    def test_custom_implementation(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
         """Tests numerical equivalence between the custom (CUDA) implementation and the reference torch implementation"""
         nlat_in, nlon_in = in_shape
@@ -161,7 +161,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         ],
         skip_on_empty=True,
     )
-    def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
+    def test_neighborhood_global_equivalence(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
         """Tests numerical equivalence between the global spherical attention module and the neighborhood spherical attention module with the neighborhood set to the whole sphere"""
         nlat_in, nlon_in = in_shape
@@ -223,7 +223,7 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
         skip_on_empty=True,
     )
     @unittest.skipUnless((torch.cuda.is_available() and _cuda_extension_available), "skipping performance test because CUDA is not available")
-    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=True):
+    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
         # extract some parameters
         nlat_in, nlon_in = in_shape
@@ -479,9 +479,10 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         qw = qw.reshape(B*nh, -1, H, W)
 
         # convert to float32
-        kw = kw.to(torch.float32)
-        vw = vw.to(torch.float32)
-        qw = qw.to(torch.float32)
+        inp_dtype = kw.dtype
+        kw = kw.to(torch.float32).contiguous()
+        vw = vw.to(torch.float32).contiguous()
+        qw = qw.to(torch.float32).contiguous()
 
         output = attention_cuda_extension.forward(kw, vw, qw, quad_weights,
                                                   col_idx, row_off,
@@ -490,6 +491,9 @@ class _NeighborhoodAttentionS2Cuda(torch.autograd.Function):
         _, C, H, W = output.shape
         output = output.reshape(B, -1, H, W)
 
+        # convert back precision
+        output = output.to(dtype=inp_dtype)
+
         return output
 
     @staticmethod
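The two hunks above follow a common pattern when wrapping a hand-written CUDA kernel in an `autograd.Function`: remember the caller's dtype, hand the kernel contiguous float32 tensors, and cast the result back so the module stays transparent under mixed precision. The `.contiguous()` calls guard against non-contiguous inputs (e.g. from upstream reshapes or permutes), since a custom kernel typically indexes raw, densely packed memory. The sketch below only illustrates that pattern; `fake_kernel_forward` is a hypothetical stand-in for the real `attention_cuda_extension.forward` call and its signature is made up.

```python
import torch


def fake_kernel_forward(kw: torch.Tensor, vw: torch.Tensor, qw: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a custom CUDA extension that expects
    contiguous float32 inputs and returns a float32 result."""
    assert kw.dtype == vw.dtype == qw.dtype == torch.float32
    assert kw.is_contiguous() and vw.is_contiguous() and qw.is_contiguous()
    return torch.zeros_like(qw)  # dummy result in the kernel's float32 dtype


def forward_with_dtype_roundtrip(kw, vw, qw):
    # remember the caller's dtype (e.g. float16/bfloat16 under autocast)
    inp_dtype = kw.dtype
    # upcast to float32 and force a dense layout before calling the kernel,
    # which indexes raw memory and does not handle arbitrary strides
    kw = kw.to(torch.float32).contiguous()
    vw = vw.to(torch.float32).contiguous()
    qw = qw.to(torch.float32).contiguous()
    output = fake_kernel_forward(kw, vw, qw)
    # cast the result back so callers see their original precision
    return output.to(dtype=inp_dtype)


if __name__ == "__main__":
    kw = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16)
    vw = torch.randn_like(kw)
    qw = torch.randn_like(kw)
    out = forward_with_dtype_roundtrip(kw, vw, qw)
    print(out.dtype)  # torch.bfloat16, matching the inputs
```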
@@ -291,7 +291,7 @@ class NeighborhoodAttentionS2(nn.Module):
         # set the last value
         row_offset[row + 1] = idz + 1
-        row_offset = torch.from_numpy(row_offset)
+        row_offset = torch.from_numpy(row_offset).contiguous()
 
         self.max_psi_nnz = col_idx.max().item() + 1
 
         self.register_buffer("psi_row_idx", row_idx, persistent=False)
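The last hunk applies the same layout precaution to the CSR-style row offsets that describe the neighborhood sparsity pattern. Below is a minimal, self-contained sketch (with invented buffer names and toy values, not the module's real construction code) of building such offsets in numpy and registering them as contiguous, non-persistent buffers:

```python
import numpy as np
import torch
import torch.nn as nn


class SparsityBuffers(nn.Module):
    """Toy module holding a CSR-style neighborhood pattern (illustrative only)."""

    def __init__(self):
        super().__init__()
        # toy CSR structure: 3 rows with 2, 1 and 3 nonzeros respectively
        col_idx = torch.tensor([0, 2, 1, 0, 1, 2], dtype=torch.int64)
        row_offset = np.array([0, 2, 3, 6], dtype=np.int64)
        # torch.from_numpy shares memory with the numpy array; .contiguous()
        # guarantees a dense layout before the offsets are handed to a kernel
        # that walks them with raw pointer arithmetic
        row_offset = torch.from_numpy(row_offset).contiguous()
        self.register_buffer("csr_col_idx", col_idx, persistent=False)
        self.register_buffer("csr_row_off", row_offset, persistent=False)


m = SparsityBuffers()
print(m.csr_row_off.is_contiguous())  # True
print(m.csr_row_off.tolist())         # [0, 2, 3, 6]
```

In practice `torch.from_numpy` on a freshly allocated 1-D array already yields a contiguous tensor, so the explicit `.contiguous()` is a cheap safeguard before the tensor's data pointer is passed to the CUDA extension.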