Commit 87d9bfdc authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

bugfix in distributed convolution

parent e5a9c4af
...@@ -70,6 +70,7 @@ if [ "$run_distributed" = "true" ]; then ...@@ -70,6 +70,7 @@ if [ "$run_distributed" = "true" ]; then
export GRID_W=${grid_size_lon}; export GRID_W=${grid_size_lon};
python3 -m pytest tests/test_distributed_sht.py python3 -m pytest tests/test_distributed_sht.py
python3 -m pytest tests/test_distributed_convolution.py python3 -m pytest tests/test_distributed_convolution.py
python3 -m pytest tests/test_distributed_resample.py
" "
else else
echo "Skipping distributed tests." echo "Skipping distributed tests."
......
...@@ -177,6 +177,7 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -177,6 +177,7 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
[8, 4, 2, (24, 48), (12, 24), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (24, 48), (12, 24), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (24, 48), (12, 24), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [2, 2], "morlet", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (24, 48), (12, 24), [2, 2], "morlet", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [3], "zernike", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 24), (8, 8), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (16, 24), (8, 8), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (18, 36), (6, 12), [7], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (18, 36), (6, 12), [7], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4], [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4],
...@@ -188,6 +189,7 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -188,6 +189,7 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
[8, 4, 2, (12, 24), (24, 48), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (12, 24), (24, 48), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (12, 24), (24, 48), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [2, 2], "morlet", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (12, 24), (24, 48), [2, 2], "morlet", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [3], "zernike", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 8), (16, 24), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (8, 8), (16, 24), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (6, 12), (18, 36), [7], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (6, 12), (18, 36), [7], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4], [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4],
......
...@@ -183,18 +183,18 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -183,18 +183,18 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
[ [
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5], [129, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 64, 128, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 2, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 129, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5], [129, 256, 129, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5], [64, 128, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "individual", 2, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "individual", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
] ]
) )
def test_distributed_disco_conv( def test_distributed_disco_conv(
......
...@@ -112,11 +112,12 @@ def _precompute_distributed_convolution_tensor_s2( ...@@ -112,11 +112,12 @@ def _precompute_distributed_convolution_tensor_s2(
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0 # It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1] lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
# compute quadrature weights that will be merged into the Psi tensor # compute quadrature weights and merge them into the convolution tensor.
# These quadrature integrate to 1 over the sphere.
if transpose_normalization: if transpose_normalization:
quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
else: else:
quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0
out_idx = [] out_idx = []
out_vals = [] out_vals = []
...@@ -129,11 +130,11 @@ def _precompute_distributed_convolution_tensor_s2( ...@@ -129,11 +130,11 @@ def _precompute_distributed_convolution_tensor_s2(
# compute cartesian coordinates of the rotated position # compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation, # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign # and therefore applied with a negative sign
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha) x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma) y = torch.sin(beta) * torch.sin(gamma)
z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
# normalization is emportant to avoid NaNs when arccos and atan are applied # normalization is important to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution # this can otherwise lead to spurious artifacts in the solution
norm = torch.sqrt(x * x + y * y + z * z) norm = torch.sqrt(x * x + y * y + z * z)
x = x / norm x = x / norm
...@@ -270,6 +271,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv): ...@@ -270,6 +271,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous() roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous()
self.register_buffer("psi_roff_idx", roff_idx, persistent=False) self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
# save all datastructures
self.register_buffer("psi_ker_idx", ker_idx, persistent=False) self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
self.register_buffer("psi_row_idx", row_idx, persistent=False) self.register_buffer("psi_row_idx", row_idx, persistent=False)
self.register_buffer("psi_col_idx", col_idx, persistent=False) self.register_buffer("psi_col_idx", col_idx, persistent=False)
...@@ -412,6 +414,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -412,6 +414,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals).contiguous() roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals).contiguous()
self.register_buffer("psi_roff_idx", roff_idx, persistent=False) self.register_buffer("psi_roff_idx", roff_idx, persistent=False)
# save all datastructures
self.register_buffer("psi_ker_idx", ker_idx, persistent=False) self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
self.register_buffer("psi_row_idx", row_idx, persistent=False) self.register_buffer("psi_row_idx", row_idx, persistent=False)
self.register_buffer("psi_col_idx", col_idx, persistent=False) self.register_buffer("psi_col_idx", col_idx, persistent=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment