Unverified Commit a0227d4c authored by Boris Bonev's avatar Boris Bonev Committed by GitHub
Browse files

Bbonev/discrete continuous convolutions (#25)



* Fixing the precomputation of the Psi matrix

* moving tests to a parent folder

* removing tqdm

* deleting deprecated distributed tests

* Moving distributed test

* adapting workflow

* Added some more comments to make the code more understandable

* more detailed explanation for the derivation of the rotation angles

* added another comment

---------
Co-authored-by: default avatarThorsten Kurth <tkurth@nvidia.com>
parent ad927429
...@@ -23,4 +23,4 @@ jobs: ...@@ -23,4 +23,4 @@ jobs:
- name: Test with pytest - name: Test with pytest
run: | run: |
python -m pip install pytest pytest-cov parameterized python -m pip install pytest pytest-cov parameterized
python -m pytest ./torch_harmonics/tests.py python -m pytest ./tests/test_sht.py
\ No newline at end of file \ No newline at end of file
...@@ -37,13 +37,6 @@ import torch ...@@ -37,13 +37,6 @@ import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_harmonics import * from torch_harmonics import *
# try:
# from tqdm import tqdm
# except:
# tqdm = lambda x : x
tqdm = lambda x : x
class TestLegendrePolynomials(unittest.TestCase): class TestLegendrePolynomials(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -124,7 +117,7 @@ class TestSphericalHarmonicTransform(unittest.TestCase): ...@@ -124,7 +117,7 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
base = signal base = signal
for _ in tqdm(range(iter)): for _ in range(iter):
base = isht(sht(base)) base = isht(sht(base))
err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ) err = torch.mean(torch.norm(base-signal, p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
......
...@@ -380,7 +380,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in ...@@ -380,7 +380,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
pscale = nlon_in // nlon_out pscale = nlon_in // nlon_out
# add a dummy dimension for nkernel # add a dummy dimension for nkernel and move the batch and channel dims to the end
x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1) x = x.reshape(1, batch_size * n_chans, nlat_in, nlon_in).permute(0, 2, 3, 1)
x = x.expand(kernel_size, -1, -1, -1) x = x.expand(kernel_size, -1, -1, -1)
......
...@@ -71,7 +71,17 @@ def _precompute_convolution_tensor( ...@@ -71,7 +71,17 @@ def _precompute_convolution_tensor(
""" """
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$. Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al. Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in) The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).
The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
$$
Y(\alpha) Z(\beta) Y(\gamma) n =
{\begin{bmatrix}
\cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
\sin(\beta)\sin(\gamma) \\
\cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
\end{bmatrix}}
$$
""" """
assert len(in_shape) == 2 assert len(in_shape) == 2
...@@ -95,21 +105,31 @@ def _precompute_convolution_tensor( ...@@ -95,21 +105,31 @@ def _precompute_convolution_tensor(
out_vals = torch.empty([0], dtype=torch.long) out_vals = torch.empty([0], dtype=torch.long)
# compute the phi differences # compute the phi differences
phis = torch.linspace(0, 2 * math.pi, nlon_in) # It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2*math.pi, nlon_in+1)[:-1]
for t in range(nlat_out): for t in range(nlat_out):
alpha = -lats_in.reshape(-1, 1) # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
beta = phis alpha = - lats_out[t]
gamma = lats_out[t] beta = lons_in
gamma = lats_in.reshape(-1, 1)
# compute latitude of the rotated position
z = torch.cos(alpha) * torch.cos(gamma) - torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma)
z = torch.clamp(z, min=-1.0, max=1.0)
theta = torch.arccos(z)
# compute cartesian coordinates of the rotated position # compute cartesian coordinates of the rotated position
x = torch.cos(beta) * torch.sin(alpha) + torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign
z = - torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
y = torch.sin(beta) * torch.sin(gamma) y = torch.sin(beta) * torch.sin(gamma)
# normalization is important to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
norm = torch.sqrt(x*x + y*y + z*z)
x = x / norm
y = y / norm
z = z / norm
# compute spherical coordinates
theta = torch.arccos(z)
phi = torch.arctan2(y, x) phi = torch.arctan2(y, x)
# find the indices where the rotated position falls into the support of the kernel # find the indices where the rotated position falls into the support of the kernel
...@@ -159,7 +179,7 @@ class DiscreteContinuousConvS2(nn.Module): ...@@ -159,7 +179,7 @@ class DiscreteContinuousConvS2(nn.Module):
for kdim in kernel_shape: for kdim in kernel_shape:
self.kernel_size *= kdim self.kernel_size *= kdim
# bandlimit # compute theta cutoff based on the bandlimit of the input field
if theta_cutoff is None: if theta_cutoff is None:
theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1) theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment