Commit 3d604f85 authored by Thorsten Kurth

renamings

parent 0767c39c
@@ -39,7 +39,7 @@ from torch.autograd import gradcheck
 from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
 from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
 from torch_harmonics.convolution import _precompute_convolution_tensor_s2

 _devices = [(torch.device("cpu"),)]
 if torch.cuda.is_available():
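The _devices list above builds the device axis for the test parameterization. A minimal sketch of how such a list is typically consumed with parameterized_class (an assumption about this test file's setup, not part of the commit):

    # Sketch: run one TestCase per available device via parameterized_class.
    # The class and test names here are hypothetical illustrations.
    import unittest

    import torch
    from parameterized import parameterized_class

    _devices = [(torch.device("cpu"),)]
    if torch.cuda.is_available():
        _devices.append((torch.device("cuda"),))

    @parameterized_class(("device",), _devices)
    class DeviceDemo(unittest.TestCase):
        def test_ones(self):
            # each parameterized copy of the class sees its own self.device
            x = torch.ones(3, device=self.device)
            self.assertEqual(x.device.type, self.device.type)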
@@ -199,9 +199,10 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, False],
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, False],
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_disco_convolution(
+    def test_forward_backward(
         self,
         batch_size,
         in_channels,
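The new skip_on_empty=True argument to parameterized.expand appears throughout this commit. A standalone sketch of its effect, assuming the documented behavior of the parameterized library (the _cuda_cases filter below is hypothetical):

    # Sketch: with skip_on_empty=True, an empty parameter list marks the
    # test as skipped instead of failing at decoration time.
    import unittest

    import torch
    from parameterized import parameterized

    # hypothetical filter: CUDA-only cases vanish when CUDA is absent
    _cuda_cases = [(8, 16)] if torch.cuda.is_available() else []

    class SkipOnEmptyDemo(unittest.TestCase):
        @parameterized.expand(_cuda_cases, skip_on_empty=True)
        def test_cuda_shapes(self, nlat, nlon):
            # trivial check on the (nlat, nlon) pair
            self.assertEqual(2 * nlat, nlon)

    if __name__ == "__main__":
        unittest.main()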
@@ -321,11 +322,12 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
             [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
             [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
             [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
     def test_device_instantiation(self, batch_size, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, transpose, tol, verbose):
         nlat_in, nlon_in = in_shape
         nlat_out, nlon_out = out_shape
@@ -334,7 +336,10 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         else:
             theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)

+        # get handle
         Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2

+        # init on cpu
         conv_host = Conv(
             in_channels,
             out_channels,
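For reference, the theta_cutoff expression visible in the hunk above appears to tie the kernel's angular support to the latitudinal grid spacing. A standalone sketch with assumed example values (not taken from the tests):

    # Sketch of the cutoff formula: the latitudinal spacing of a grid with
    # nlat_in rows (poles included) is roughly pi / (nlat_in - 1), so a
    # kernel with K radial collocation points spans about K + 1 grid rows.
    import torch

    nlat_in = 16          # example value (assumption)
    kernel_shape = (5,)   # example radial collocation point count

    theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
    print(theta_cutoff)   # 0.4 * pi, about 1.2566 rad for these values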
@@ -350,9 +355,9 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
             theta_cutoff=theta_cutoff,
         )

-        torch.set_default_device(self.device)
-        #with(self.device):
-        conv_device = Conv(
-            in_channels,
-            out_channels,
-            in_shape,
+        #torch.set_default_device(self.device)
+        with torch.device(self.device):
+            conv_device = Conv(
+                in_channels,
+                out_channels,
+                in_shape,
@@ -367,8 +372,8 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
-            theta_cutoff=theta_cutoff,
-        )
+                theta_cutoff=theta_cutoff,
+            )

-        print(conv_host.psi_col_idx.device, conv_device.psi_col_idx.device)
+        # since we specified the device specifier everywhere, it should always
+        # use the cpu and it should be the same everywhere
         self.assertTrue(torch.allclose(conv_host.psi_col_idx.cpu(), conv_device.psi_col_idx.cpu()))
         self.assertTrue(torch.allclose(conv_host.psi_row_idx.cpu(), conv_device.psi_row_idx.cpu()))
         self.assertTrue(torch.allclose(conv_host.psi_roff_idx.cpu(), conv_device.psi_roff_idx.cpu()))
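The switch from torch.set_default_device to with torch.device(...) is the functional core of this hunk. A sketch of the difference (torch.device has worked as a context manager since PyTorch 2.0):

    # Sketch: the context manager scopes the default device to the block,
    # whereas torch.set_default_device mutates global state that would
    # leak into subsequent tests.
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.device(device):
        inside = torch.zeros(4)   # created on `device`

    outside = torch.zeros(4)      # created on the global default (cpu)
    print(inside.device, outside.device)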
...
@@ -101,15 +101,16 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
             [33, 64, 32, "ortho", "equiangular", 1e-9, False],
             [33, 64, 32, "ortho", "legendre-gauss", 1e-9, False],
             [33, 64, 32, "ortho", "lobatto", 1e-9, False],
             [33, 64, 32, "four-pi", "equiangular", 1e-9, False],
             [33, 64, 32, "four-pi", "legendre-gauss", 1e-9, False],
             [33, 64, 32, "four-pi", "lobatto", 1e-9, False],
             [33, 64, 32, "schmidt", "equiangular", 1e-9, False],
             [33, 64, 32, "schmidt", "legendre-gauss", 1e-9, False],
             [33, 64, 32, "schmidt", "lobatto", 1e-9, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
+    def test_forward_inverse(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
         if verbose:
             print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")
@@ -168,9 +169,10 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
             [15, 30, 2, "schmidt", "equiangular", 1e-5, False],
             [15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
             [15, 30, 2, "schmidt", "lobatto", 1e-5, False],
-        ]
+        ],
+        skip_on_empty=True,
     )
-    def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
+    def test_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
         if verbose:
             print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
@@ -202,6 +204,40 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
         test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
         self.assertTrue(test_result)

+    @parameterized.expand(
+        [
+            # even-even
+            [12, 24, 2, "ortho", "equiangular", 1e-5, False],
+            [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
+            [12, 24, 2, "ortho", "lobatto", 1e-5, False],
+        ],
+        skip_on_empty=True,
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
+    def test_device_instantiation(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
+        if verbose:
+            print(f"Testing device instantiation of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
+
+        if grid == "equiangular":
+            mmax = nlat // 2
+        elif grid == "lobatto":
+            mmax = nlat - 1
+        else:
+            mmax = nlat
+        lmax = mmax
+
+        # init on cpu
+        sht_host = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+        isht_host = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+
+        # init on device
+        with torch.device(self.device):
+            sht_device = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+            isht_device = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
+
+        self.assertTrue(torch.allclose(sht_host.weights.cpu(), sht_device.weights.cpu()))
+        self.assertTrue(torch.allclose(isht_host.pct.cpu(), isht_device.pct.cpu()))

 if __name__ == "__main__":
     unittest.main()
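The new test_device_instantiation asserts that precomputed buffers (quadrature weights, projection tensors) match whether a transform is built on the host or under a device context. A usage sketch of the same pattern; the grid and norm values are arbitrary examples, and torch_harmonics is assumed importable as th, as in the tests above:

    # Sketch: buffers precomputed at construction time should agree
    # regardless of the device the module was instantiated on.
    import torch
    import torch_harmonics as th

    nlat, nlon = 12, 24
    sht_host = th.RealSHT(nlat, nlon, grid="equiangular", norm="ortho")

    if torch.cuda.is_available():
        with torch.device("cuda"):
            sht_device = th.RealSHT(nlat, nlon, grid="equiangular", norm="ortho")
        # quadrature weights should agree across host and device builds
        assert torch.allclose(sht_host.weights.cpu(), sht_device.weights.cpu())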