Commit ea8d1a2e authored by Thorsten Kurth's avatar Thorsten Kurth
Browse files

adding device args to more functions

parent c877cda6
......@@ -127,7 +127,7 @@ def _precompute_convolution_tensor_dense(
quad_weights = win.reshape(-1, 1) / nlon_in / 2.0
# array for accumulating non-zero indices
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64, device=lons_in.device)
for t in range(nlat_out):
for p in range(nlon_out):
......@@ -315,6 +315,66 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol))
self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))
@parameterized.expand(
    [
        [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
        [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
        [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
        [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
    ]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
def test_device_instantiation(self, batch_size, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, transpose, tol, verbose):
    """Instantiating the convolution directly on an accelerator device must produce
    the same precomputed sparse index/value tensors as instantiating it on the host.

    This guards the ``device=`` plumbing in the precomputation routines: a factory
    call that silently falls back to the CPU default would either crash during
    construction (mixed-device ops) or leave buffers on the wrong device.
    """
    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # cutoff radius of the filter basis, derived from the kernel extent
    if isinstance(kernel_shape, int):
        theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1)
    else:
        theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)

    Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2

    # identical keyword arguments for both constructions, so any difference in
    # the result can only come from the device placement
    conv_kwargs = dict(
        basis_type=basis_type,
        basis_norm_mode=basis_norm_mode,
        groups=1,
        grid_in=grid_in,
        grid_out=grid_out,
        bias=False,
        theta_cutoff=theta_cutoff,
    )

    # reference: construct on the (CPU) host with the ordinary default device
    conv_host = Conv(in_channels, out_channels, in_shape, out_shape, kernel_shape, **conv_kwargs)

    # construct on the target device; the context manager scopes the default
    # device to this block only, unlike torch.set_default_device which would
    # leak global state into every subsequently-run test
    with torch.device(self.device):
        conv_device = Conv(in_channels, out_channels, in_shape, out_shape, kernel_shape, **conv_kwargs)

    # the device-constructed buffers must actually live on the target device
    self.assertEqual(conv_device.psi_col_idx.device.type, torch.device(self.device).type)

    # index structure and values of the precomputed sparse tensors must agree
    # between host and device construction (compare on CPU)
    self.assertTrue(torch.allclose(conv_host.psi_col_idx.cpu(), conv_device.psi_col_idx.cpu()))
    self.assertTrue(torch.allclose(conv_host.psi_row_idx.cpu(), conv_device.psi_row_idx.cpu()))
    self.assertTrue(torch.allclose(conv_host.psi_roff_idx.cpu(), conv_device.psi_roff_idx.cpu()))
    self.assertTrue(torch.allclose(conv_host.psi_vals.cpu(), conv_device.psi_vals.cpu()))
    self.assertTrue(torch.allclose(conv_host.psi_idx.cpu(), conv_device.psi_idx.cpu()))
# Allow running this test module directly (e.g. `python thisfile.py`)
# in addition to discovery via a test runner.
if __name__ == "__main__":
    unittest.main()
......@@ -254,7 +254,7 @@ class MorletFilterBasis(FilterBasis):
mkernel = ikernel // self.kernel_shape[1]
# get relevant indices
iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device))
# get corresponding r, phi, x and y coordinates
r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff
......@@ -316,10 +316,10 @@ class ZernikeFilterBasis(FilterBasis):
"""
# enumerator for basis function
ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)
# get relevant indices
iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device))
# indexing logic for zernike polynomials
# the total index is given by (n * (n + 2) + l ) // 2 which needs to be reversed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment