Unverified Commit c971d458 authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

Bbonev/disco refactor (#28)

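Changed the modules so that they no longer register a sparse tensor as a buffer: only the index and value tensors (`psi_idx`, `psi_vals`) are registered, and the sparse COO tensor `psi` is rebuilt from them each time the module is called. This enables compatibility with DDP (DistributedDataParallel).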
parent 942aa4ea
......@@ -221,8 +221,10 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
).to(self.device)
psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()
self.assertTrue(
torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
)
# create a copy of the weight
......
......@@ -222,10 +222,12 @@ class DiscreteContinuousConvS2(nn.Module):
idx, vals = _precompute_convolution_tensor(
in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
)
psi = torch.sparse_coo_tensor(
idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
).coalesce()
self.register_buffer("psi", psi, persistent=False)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
# ).coalesce()
self.register_buffer("psi_idx", idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# self.register_buffer("psi", psi, persistent=False)
# groups
self.groups = groups
......@@ -248,10 +250,12 @@ class DiscreteContinuousConvS2(nn.Module):
# pre-multiply x with the quadrature weights
x = self.quad_weights * x
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
if x.is_cuda and use_triton_kernel:
x = _disco_s2_contraction_triton(x, self.psi, self.nlon_out)
x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
else:
x = _disco_s2_contraction_torch(x, self.psi, self.nlon_out)
x = _disco_s2_contraction_torch(x, psi, self.nlon_out)
# extract shape
B, C, K, H, W = x.shape
......@@ -317,10 +321,12 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
idx, vals = _precompute_convolution_tensor(
out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
)
psi = torch.sparse_coo_tensor(
idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
).coalesce()
self.register_buffer("psi", psi, persistent=False)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
# ).coalesce()
self.register_buffer("psi_idx", idx, persistent=False)
self.register_buffer("psi_vals", vals, persistent=False)
# self.register_buffer("psi", psi, persistent=False)
# groups
self.groups = groups
......@@ -351,10 +357,12 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
# pre-multiply x with the quadrature weights
x = self.quad_weights * x
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
if x.is_cuda and use_triton_kernel:
out = _disco_s2_transpose_contraction_triton(x, self.psi, self.nlon_out)
out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
else:
out = _disco_s2_transpose_contraction_torch(x, self.psi, self.nlon_out)
out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)
if self.bias is not None:
out = out + self.bias.reshape(1, -1, 1, 1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment