Unverified commit 6730e5c1 authored by Thorsten Kurth, committed by GitHub

Tkurth/torchification (#66)

* adding caching

* replacing many numpy calls with torch calls

* bumping up version number to 0.7.6
parent 780fd143
......@@ -2,7 +2,16 @@
## Versioning
### v0.7.6
* Adding a cache for precomputed tensors such as the weight tensors for DISCO and SHT
* The cache returns copies of tensors rather than references. Users are still encouraged to re-use
those tensors manually in their models, since doing so also saves memory; the cache mainly
speeds up model setup.
* Adding a test which ensures that the cache works correctly
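A minimal sketch of the copy semantics (hypothetical usage of the internal `_precompute_legpoly` helper, assuming `torch` and `math` are imported):

    t = torch.linspace(0.0, math.pi, 10, dtype=torch.float64)
    leg_a = _precompute_legpoly(10, 10, t)  # first call populates the cache
    leg_b = _precompute_legpoly(10, 10, t)  # second call is served from the cache
    leg_a *= -1.0                           # in-place edits do not leak into the cache
    assert not torch.allclose(leg_a, leg_b)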
### v0.7.5
* New normalization mode `support` for DISCO convolutions
* More efficient computation of Morlet filter basis
* Changed default for Morlet filter basis to a Hann window function
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import unittest
from parameterized import parameterized
import math
import torch
class TestCacheConsistency(unittest.TestCase):
def test_consistency(self, verbose=False):
if verbose:
print("Testing that cache values does not get modified externally")
from torch_harmonics.legendre import _precompute_legpoly
with torch.no_grad():
cost = torch.cos(torch.linspace(0.0, 2.0 * math.pi, 10, dtype=torch.float64))
leg1 = _precompute_legpoly(10, 10, cost)
# perform in-place modification of leg1
leg1 *= -1.0
leg2 = _precompute_legpoly(10, 10, cost)
self.assertFalse(torch.allclose(leg1, leg2))
if __name__ == "__main__":
unittest.main()
......@@ -38,7 +38,7 @@ import torch
from torch.autograd import gradcheck
from torch_harmonics import *
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
......@@ -106,22 +106,20 @@ def _precompute_convolution_tensor_dense(
nlat_out, nlon_out = out_shape
lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in)
lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out)
# compute the phi differences. We need to make the linspace exclusive to not double the last point
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1, dtype=torch.float64)[:-1]
# compute the phi differences.
lons_in = _precompute_longitudes(nlon_in)
lons_out = _precompute_longitudes(nlon_out)
# the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
# compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization:
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
quad_weights = win.reshape(-1, 1) / nlon_in / 2.0
# array for accumulating non-zero indices
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
......@@ -172,34 +170,32 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
else:
self.device = torch.device("cpu")
self.device = torch.device("cpu")
@parameterized.expand(
[
# regular convolution
[8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [2, 2], "morlet", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (24, 48), (12, 24), [3], "zernike", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 24), (8, 8), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (18, 36), (6, 12), [7], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4],
[8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4, False],
[8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4, False],
[8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
# transpose convolution
[8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [2, 2], "morlet", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (12, 24), (24, 48), [3], "zernike", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 8), (16, 24), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (6, 12), (18, 36), [7], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4],
[8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, False],
[8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, False],
[8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
]
)
def test_disco_convolution(
......@@ -216,11 +212,19 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
grid_out,
transpose,
tol,
verbose,
):
if verbose:
print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device")
nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape
theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
if isinstance(kernel_shape, int):
theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1)
else:
theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
conv = Conv(
......
......@@ -183,18 +183,18 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@parameterized.expand(
[
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 129, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 129, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
]
)
def test_distributed_disco_conv(
......
......@@ -183,14 +183,14 @@ class TestDistributedResampling(unittest.TestCase):
@parameterized.expand(
[
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
]
)
def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
......@@ -248,7 +248,7 @@ class TestDistributedResampling(unittest.TestCase):
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
if verbose and (self.world_rank == 0):
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
......@@ -259,7 +259,7 @@ class TestDistributedResampling(unittest.TestCase):
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
if verbose and (self.world_rank == 0):
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)
......
......@@ -32,7 +32,6 @@
import unittest
from parameterized import parameterized
import math
import numpy as np
import torch
from torch.autograd import gradcheck
from torch_harmonics import *
......@@ -41,31 +40,32 @@ from torch_harmonics import *
class TestLegendrePolynomials(unittest.TestCase):
def setUp(self):
self.cml = lambda m, l: np.sqrt((2 * l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l - m) / math.factorial(l + m))
self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
self.pml = dict()
# preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
# for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
self.pml[(0, 0)] = lambda x: np.ones_like(x)
self.pml[(0, 0)] = lambda x: torch.ones_like(x)
self.pml[(0, 1)] = lambda x: x
self.pml[(1, 1)] = lambda x: -np.sqrt(1.0 - x**2)
self.pml[(1, 1)] = lambda x: -torch.sqrt(1.0 - x**2)
self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1)
self.pml[(1, 2)] = lambda x: -3 * x * np.sqrt(1.0 - x**2)
self.pml[(1, 2)] = lambda x: -3 * x * torch.sqrt(1.0 - x**2)
self.pml[(2, 2)] = lambda x: 3 * (1 - x**2)
self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x)
self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * np.sqrt(1.0 - x**2)
self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * torch.sqrt(1.0 - x**2)
self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2)
self.pml[(3, 3)] = lambda x: -15 * np.sqrt(1.0 - x**2) ** 3
self.pml[(3, 3)] = lambda x: -15 * torch.sqrt(1.0 - x**2) ** 3
self.lmax = self.mmax = 4
self.tol = 1e-9
def test_legendre(self):
print("Testing computation of associated Legendre polynomials")
def test_legendre(self, verbose=False):
if verbose:
print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import legpoly
t = np.linspace(0, 1, 100)
t = torch.linspace(0, 1, 100, dtype=torch.float64)
vdm = legpoly(self.mmax, self.lmax, t)
for l in range(self.lmax):
......@@ -79,24 +79,23 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
def setUp(self):
if torch.cuda.is_available():
print("Running test on GPU")
self.device = torch.device("cuda")
else:
print("Running test on CPU")
self.device = torch.device("cpu")
@parameterized.expand(
[
[256, 512, 32, "ortho", "equiangular", 1e-9],
[256, 512, 32, "ortho", "legendre-gauss", 1e-9],
[256, 512, 32, "four-pi", "equiangular", 1e-9],
[256, 512, 32, "four-pi", "legendre-gauss", 1e-9],
[256, 512, 32, "schmidt", "equiangular", 1e-9],
[256, 512, 32, "schmidt", "legendre-gauss", 1e-9],
[256, 512, 32, "ortho", "equiangular", 1e-9, False],
[256, 512, 32, "ortho", "legendre-gauss", 1e-9, False],
[256, 512, 32, "four-pi", "equiangular", 1e-9, False],
[256, 512, 32, "four-pi", "legendre-gauss", 1e-9, False],
[256, 512, 32, "schmidt", "equiangular", 1e-9, False],
[256, 512, 32, "schmidt", "legendre-gauss", 1e-9, False],
]
)
def test_sht(self, nlat, nlon, batch_size, norm, grid, tol):
print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
if verbose:
print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")
testiters = [1, 2, 4, 8, 16]
if grid == "equiangular":
......@@ -116,7 +115,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
# testing error accumulation
for iter in testiters:
with self.subTest(i=iter):
print(f"{iter} iterations of batchsize {batch_size}:")
if verbose:
print(f"{iter} iterations of batchsize {batch_size}:")
base = signal
......@@ -124,27 +124,29 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
base = isht(sht(base))
err = torch.mean(torch.norm(base - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))
print(f"final relative error: {err.item()}")
if verbose:
print(f"final relative error: {err.item()}")
self.assertTrue(err.item() <= tol)
@parameterized.expand(
[
[12, 24, 2, "ortho", "equiangular", 1e-5],
[12, 24, 2, "ortho", "legendre-gauss", 1e-5],
[12, 24, 2, "four-pi", "equiangular", 1e-5],
[12, 24, 2, "four-pi", "legendre-gauss", 1e-5],
[12, 24, 2, "schmidt", "equiangular", 1e-5],
[12, 24, 2, "schmidt", "legendre-gauss", 1e-5],
[15, 30, 2, "ortho", "equiangular", 1e-5],
[15, 30, 2, "ortho", "legendre-gauss", 1e-5],
[15, 30, 2, "four-pi", "equiangular", 1e-5],
[15, 30, 2, "four-pi", "legendre-gauss", 1e-5],
[15, 30, 2, "schmidt", "equiangular", 1e-5],
[15, 30, 2, "schmidt", "legendre-gauss", 1e-5],
[12, 24, 2, "ortho", "equiangular", 1e-5, False],
[12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
[12, 24, 2, "four-pi", "equiangular", 1e-5, False],
[12, 24, 2, "four-pi", "legendre-gauss", 1e-5, False],
[12, 24, 2, "schmidt", "equiangular", 1e-5, False],
[12, 24, 2, "schmidt", "legendre-gauss", 1e-5, False],
[15, 30, 2, "ortho", "equiangular", 1e-5, False],
[15, 30, 2, "ortho", "legendre-gauss", 1e-5, False],
[15, 30, 2, "four-pi", "equiangular", 1e-5, False],
[15, 30, 2, "four-pi", "legendre-gauss", 1e-5, False],
[15, 30, 2, "schmidt", "equiangular", 1e-5, False],
[15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
]
)
def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol):
print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
if verbose:
print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
if grid == "equiangular":
mmax = nlat // 2
......
......@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__ = "0.7.5a"
__version__ = "0.7.6"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import functools
from copy import deepcopy
# copying LRU cache decorator a la:
# https://stackoverflow.com/questions/54909357/how-to-get-functools-lru-cache-to-return-new-instances
def lru_cache(maxsize=20, typed=False, copy=False):
def decorator(f):
cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)
def wrapper(*args, **kwargs):
res = cached_func(*args, **kwargs)
if copy:
return deepcopy(res)
else:
return res
return wrapper
return decorator
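A brief usage sketch (hypothetical function, not part of this diff): with `copy=True`, every call site receives an independent deepcopy of the cached result, so in-place edits cannot corrupt the cache.

@lru_cache(maxsize=20, typed=True, copy=True)
def _expensive_setup(n):
    return [0.0] * n

a = _expensive_setup(4)
b = _expensive_setup(4)  # cache hit
a[0] = 1.0               # mutates only the local copy
assert b[0] == 0.0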
......@@ -40,10 +40,11 @@ import torch.nn as nn
from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis
# import custom C++/CUDA extensions if available
try:
......@@ -134,17 +135,18 @@ def _normalize_convolution_tensor_s2(
return psi_vals
@lru_cache(typed=True, copy=True)
def _precompute_convolution_tensor_s2(
in_shape,
out_shape,
filter_basis,
grid_in="equiangular",
grid_out="equiangular",
theta_cutoff=0.01 * math.pi,
theta_eps = 1e-3,
transpose_normalization=False,
basis_norm_mode="mean",
merge_quadrature=False,
in_shape: Tuple[int],
out_shape: Tuple[int],
filter_basis: FilterBasis,
grid_in: Optional[str]="equiangular",
grid_out: Optional[str]="equiangular",
theta_cutoff: Optional[float]=0.01 * math.pi,
theta_eps: Optional[float]=1e-3,
transpose_normalization: Optional[bool]=False,
basis_norm_mode: Optional[str]="mean",
merge_quadrature: Optional[bool]=False,
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
......@@ -172,20 +174,18 @@ def _precompute_convolution_tensor_s2(
# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out)
# compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
lons_in = _precompute_longitudes(nlon_in)
# compute quadrature weights and merge them into the convolution tensor.
# These quadrature weights integrate to 1 over the sphere.
if transpose_normalization:
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
quad_weights = win.reshape(-1, 1) / nlon_in / 2.0
# the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
......@@ -258,7 +258,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
self,
in_channels: int,
out_channels: int,
kernel_shape: Union[int, List[int]],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1,
bias: Optional[bool] = True,
......@@ -309,7 +309,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
out_channels: int,
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1,
......@@ -415,7 +415,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_channels: int,
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1,
......
......@@ -41,7 +41,7 @@ import torch.nn as nn
from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis
......@@ -106,20 +106,18 @@ def _precompute_distributed_convolution_tensor_s2(
# precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out)
# compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]
lons_in = _precompute_longitudes(nlon_in)
# compute quadrature weights and merge them into the convolution tensor.
# These quadrature weights integrate to 1 over the sphere.
if transpose_normalization:
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
else:
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0
quad_weights = win.reshape(-1, 1) / nlon_in / 2.0
# the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
......@@ -215,7 +213,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
out_channels: int,
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1,
......@@ -356,7 +354,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_channels: int,
in_shape: Tuple[int],
out_shape: Tuple[int],
kernel_shape: Union[int, List[int]],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1,
......
......@@ -31,12 +31,11 @@
from typing import List, Tuple, Union, Optional
import math
import numpy as np
import torch
import torch.nn as nn
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes
......@@ -82,54 +81,52 @@ class DistributedResampleS2(nn.Module):
# for upscaling the latitudes we will use interpolation
self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
self.lons_in = np.linspace(0, 2 * math.pi, nlon_in, endpoint=False)
self.lons_in = _precompute_longitudes(nlon_in)
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = np.linspace(0, 2 * math.pi, nlon_out, endpoint=False)
self.lons_out = _precompute_longitudes(nlon_out)
# in the case where some points lie outside of the range spanned by lats_in,
# we need to expand the solution to the poles before interpolating
self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
if self.expand_poles:
self.lats_in = np.insert(self.lats_in, 0, 0.0)
self.lats_in = np.append(self.lats_in, np.pi)
self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
self.lats_in,
torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
#self.lats_in = np.insert(self.lats_in, 0, 0.0)
#self.lats_in = np.append(self.lats_in, np.pi)
# prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1
lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
# make sure that we properly treat the last point if it coincides with the pole
lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
lat_idx = torch.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
# lat_idx = np.where(self.lats_out > self.lats_in[-1], lat_idx - 1, lat_idx)
# lat_idx = np.where(self.lats_out < self.lats_in[0], 0, lat_idx)
# compute the interpolation weights along the latitude
lat_weights = torch.from_numpy((self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx]).float()
lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
lat_weights = lat_weights.unsqueeze(-1)
# convert to tensor
lat_idx = torch.LongTensor(lat_idx)
# register buffers
self.register_buffer("lat_idx", lat_idx, persistent=False)
self.register_buffer("lat_weights", lat_weights, persistent=False)
# get left and right indices but this time make sure periodicity in the longitude is handled
lon_idx_left = np.searchsorted(self.lons_in, self.lons_out, side="right") - 1
lon_idx_right = np.where(self.lons_out >= self.lons_in[-1], np.zeros_like(lon_idx_left), lon_idx_left + 1)
lon_idx_left = torch.searchsorted(self.lons_in, self.lons_out, side="right") - 1
lon_idx_right = torch.where(self.lons_out >= self.lons_in[-1], torch.zeros_like(lon_idx_left), lon_idx_left + 1)
# get the difference
diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
diff = np.where(diff < 0.0, diff + 2 * math.pi, diff)
lon_weights = torch.from_numpy((self.lons_out - self.lons_in[lon_idx_left]) / diff).float()
# convert to tensor
lon_idx_left = torch.LongTensor(lon_idx_left)
lon_idx_right = torch.LongTensor(lon_idx_right)
diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)
lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)
# register buffers
self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
self.register_buffer("lon_weights", lon_weights, persistent=False)
self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)
def extra_repr(self):
r"""
Pretty print module
......@@ -172,6 +169,9 @@ class DistributedResampleS2(nn.Module):
def forward(self, x: torch.Tensor):
if self.skip_resampling:
return x
# transpose data so that h is local, and channels are split
num_chans = x.shape[-3]
......
......@@ -30,7 +30,6 @@
#
import os
import numpy as np
import torch
import torch.nn as nn
import torch.fft
......@@ -75,13 +74,13 @@ class DistributedRealSHT(nn.Module):
# compute quadrature points
if self.grid == "legendre-gauss":
cost, w = legendre_gauss_weights(nlat, -1, 1)
cost, weights = legendre_gauss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1)
cost, weights = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
# cost, w = fejer2_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
......@@ -94,7 +93,7 @@ class DistributedRealSHT(nn.Module):
self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
tq = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
......@@ -106,13 +105,11 @@ class DistributedRealSHT(nn.Module):
self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)
# combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights)
# split weights
weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
......@@ -208,7 +205,7 @@ class DistributedInverseRealSHT(nn.Module):
self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
t = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
......@@ -221,10 +218,9 @@ class DistributedInverseRealSHT(nn.Module):
# compute Legendre polynomials
pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)
# split in m
pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()
# register
self.register_buffer('pct', pct, persistent=False)
......@@ -308,13 +304,13 @@ class DistributedRealVectorSHT(nn.Module):
# compute quadrature points
if self.grid == "legendre-gauss":
cost, w = legendre_gauss_weights(nlat, -1, 1)
cost, weights = legendre_gauss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1)
cost, weights = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
# cost, w = fejer2_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
......@@ -327,7 +323,7 @@ class DistributedRealVectorSHT(nn.Module):
self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
tq = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
......@@ -339,9 +335,7 @@ class DistributedRealVectorSHT(nn.Module):
self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)
# compute weights
weights = torch.from_numpy(w)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# combine integration weights and normalization factor into one:
l = torch.arange(0, self.lmax)
......@@ -352,7 +346,7 @@ class DistributedRealVectorSHT(nn.Module):
weights[1] = -1 * weights[1]
# we need to split in m, pad before:
weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
......@@ -461,7 +455,7 @@ class DistributedInverseRealVectorSHT(nn.Module):
self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
t = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
......@@ -474,10 +468,9 @@ class DistributedInverseRealVectorSHT(nn.Module):
# compute Legendre polynomials
dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# split in m
dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()
# register buffer
self.register_buffer('dpct', dpct, persistent=False)
......
......@@ -33,7 +33,9 @@
import torch
import torch.nn as nn
import torch_harmonics as harmonics
from torch_harmonics.quadrature import _precompute_longitudes
import math
import numpy as np
......@@ -74,8 +76,8 @@ class SphereSolver(nn.Module):
cost, _ = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
# apply cosine transform and flip them
lats = -torch.as_tensor(np.arcsin(cost))
lons = torch.linspace(0, 2*np.pi, self.nlon+1, dtype=torch.float64)[:nlon]
lats = -torch.arcsin(cost)
lons = _precompute_longitudes(self.nlon)
self.lmax = self.sht.lmax
self.mmax = self.sht.mmax
......@@ -162,8 +164,8 @@ class SphereSolver(nn.Module):
#ax = plt.gca(projection=proj, frameon=True)
ax = fig.add_subplot(projection=proj)
Lons = Lons*180/np.pi
Lats = Lats*180/np.pi
Lons = Lons*180/math.pi
Lats = Lats*180/math.pi
# contour data over the map.
im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
......@@ -175,4 +177,4 @@ class SphereSolver(nn.Module):
return im
def plot_specdata(self, data, fig, **kwargs):
return self.plot_griddata(self.isht(data), fig, **kwargs)
\ No newline at end of file
return self.plot_griddata(self.isht(data), fig, **kwargs)
......@@ -35,6 +35,7 @@ import torch.nn as nn
import torch_harmonics as harmonics
from torch_harmonics.quadrature import *
import math
import numpy as np
......@@ -79,11 +80,11 @@ class ShallowWaterSolver(nn.Module):
elif self.grid == "equiangular":
cost, quad_weights = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
quad_weights = torch.as_tensor(quad_weights).reshape(-1, 1)
quad_weights = quad_weights.reshape(-1, 1)
# apply cosine transform and flip them
lats = -torch.as_tensor(np.arcsin(cost))
lons = torch.linspace(0, 2*np.pi, self.nlon+1, dtype=torch.float64)[:nlon]
lats = -torch.arcsin(cost)
lons = _precompute_longitudes(self.nlon)
self.lmax = self.sht.lmax
self.mmax = self.sht.mmax
......@@ -360,8 +361,8 @@ class ShallowWaterSolver(nn.Module):
#ax = plt.gca(projection=proj, frameon=True)
ax = fig.add_subplot(projection=proj)
Lons = Lons*180/np.pi
Lats = Lats*180/np.pi
Lons = Lons*180/math.pi
Lats = Lats*180/math.pi
# contour data over the map.
im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
......@@ -375,8 +376,8 @@ class ShallowWaterSolver(nn.Module):
#ax = plt.gca(projection=proj, frameon=True)
ax = fig.add_subplot(projection=proj)
Lons = Lons*180/np.pi
Lats = Lats*180/np.pi
Lons = Lons*180/math.pi
Lats = Lats*180/math.pi
# contour data over the map.
im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
......
......@@ -30,23 +30,12 @@
#
import abc
from typing import List, Tuple, Union, Optional
from typing import Tuple, Union, Optional
import math
import torch
def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis_type: str):
"""Factory function to generate the appropriate filter basis"""
if basis_type == "piecewise linear":
return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "morlet":
return MorletFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "zernike":
return ZernikeFilterBasis(kernel_shape=kernel_shape)
else:
raise ValueError(f"Unknown basis_type {basis_type}")
from torch_harmonics.cache import lru_cache
def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
......@@ -71,7 +60,7 @@ class FilterBasis(metaclass=abc.ABCMeta):
def __init__(
self,
kernel_shape: Union[int, List[int], Tuple[int, int]],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
):
self.kernel_shape = kernel_shape
......@@ -96,6 +85,20 @@ class FilterBasis(metaclass=abc.ABCMeta):
raise NotImplementedError
@lru_cache(typed=True, copy=False)
def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basis_type: str) -> FilterBasis:
"""Factory function to generate the appropriate filter basis"""
if basis_type == "piecewise linear":
return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "morlet":
return MorletFilterBasis(kernel_shape=kernel_shape)
elif basis_type == "zernike":
return ZernikeFilterBasis(kernel_shape=kernel_shape)
else:
raise ValueError(f"Unknown basis_type {basis_type}")
class PiecewiseLinearFilterBasis(FilterBasis):
"""
Tensor-product basis on a disk constructed from piecewise linear basis functions.
......@@ -103,7 +106,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
def __init__(
self,
kernel_shape: Union[int, List[int], Tuple[int, int]],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
):
if isinstance(kernel_shape, int):
......@@ -222,7 +225,7 @@ class MorletFilterBasis(FilterBasis):
def __init__(
self,
kernel_shape: Union[int, List[int], Tuple[int, int]],
kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
):
if isinstance(kernel_shape, int):
......@@ -280,7 +283,7 @@ class ZernikeFilterBasis(FilterBasis):
def __init__(
self,
kernel_shape: Union[int, Tuple[int], List[int]],
kernel_shape: Union[int, Tuple[int]],
):
if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list):
......
......@@ -29,15 +29,20 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
from typing import Optional
import math
import torch
def clm(l, m):
from torch_harmonics.cache import lru_cache
def clm(l: int, m: int) -> float:
"""
defines the normalization factor to orthonormalize the Spherical Harmonics
"""
return np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m))
return math.sqrt((2*l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l-m) / math.factorial(l+m))
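As a quick worked value (not part of this diff), clm(0, 0) is the Y_0^0 normalization 1/sqrt(4 pi):

assert math.isclose(clm(0, 0), 1.0 / math.sqrt(4.0 * math.pi))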
def legpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r"""
Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
......@@ -52,31 +57,31 @@ def legpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
# compute the tensor P^m_n:
nmax = max(mmax,lmax)
vdm = np.zeros((nmax, nmax, len(x)), dtype=np.float64)
vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64, requires_grad=False)
norm_factor = 1. if norm == "ortho" else np.sqrt(4 * np.pi)
norm_factor = 1. if norm == "ortho" else math.sqrt(4 * math.pi)
norm_factor = 1. / norm_factor if inverse else norm_factor
# initial values to start the recursion
vdm[0,0,:] = norm_factor / np.sqrt(4 * np.pi)
vdm[0,0,:] = norm_factor / math.sqrt(4 * math.pi)
# fill the diagonal and the lower diagonal
for l in range(1, nmax):
vdm[l-1, l, :] = np.sqrt(2*l + 1) * x * vdm[l-1, l-1, :]
vdm[l, l, :] = np.sqrt( (2*l + 1) * (1 + x) * (1 - x) / 2 / l ) * vdm[l-1, l-1, :]
vdm[l-1, l, :] = math.sqrt(2*l + 1) * x * vdm[l-1, l-1, :]
vdm[l, l, :] = torch.sqrt( (2*l + 1) * (1 + x) * (1 - x) / 2 / l ) * vdm[l-1, l-1, :]
# fill the remaining values on the upper triangle and multiply b
for l in range(2, nmax):
for m in range(0, l-1):
vdm[m, l, :] = x * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * vdm[m, l-1, :] \
- np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * vdm[m, l-2, :]
vdm[m, l, :] = x * math.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * vdm[m, l-1, :] \
- math.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * vdm[m, l-2, :]
if norm == "schmidt":
for l in range(0, nmax):
if inverse:
vdm[:, l, : ] = vdm[:, l, : ] * np.sqrt(2*l + 1)
vdm[:, l, : ] = vdm[:, l, : ] * math.sqrt(2*l + 1)
else:
vdm[:, l, : ] = vdm[:, l, : ] / np.sqrt(2*l + 1)
vdm[:, l, : ] = vdm[:, l, : ] / math.sqrt(2*l + 1)
vdm = vdm[:mmax, :lmax]
......@@ -86,7 +91,9 @@ def legpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
return vdm
def _precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
@lru_cache(typed=True, copy=True)
def _precompute_legpoly(mmax: int, lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r"""
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
......@@ -99,9 +106,11 @@ def _precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True
[3] Schrama, E.; Orbit integration based upon interpolated gravitational gradients
"""
return legpoly(mmax, lmax, np.cos(t), norm=norm, inverse=inverse, csphase=csphase)
return legpoly(mmax, lmax, torch.cos(t), norm=norm, inverse=inverse, csphase=csphase)
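A shape sketch under the docstring's convention (the result has shape (mmax, lmax, len(t)); assuming `torch` and `math` are imported):

theta = torch.linspace(0.0, math.pi, 16, dtype=torch.float64)
pct = _precompute_legpoly(4, 4, theta)
assert pct.shape == (4, 4, 16)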
def _precompute_dlegpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
@lru_cache(typed=True, copy=True)
def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r"""
Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
......@@ -114,32 +123,32 @@ def _precompute_dlegpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=Tru
pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False)
dpct = np.zeros((2, mmax, lmax, len(t)), dtype=np.float64)
dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64, requires_grad=False)
# fill the derivative terms wrt theta
for l in range(0, lmax):
# m = 0
dpct[0, 0, l] = - np.sqrt(l*(l+1)) * pct[1, l]
dpct[0, 0, l] = - math.sqrt(l*(l+1)) * pct[1, l]
# 0 < m < l
for m in range(1, min(l, mmax)):
dpct[0, m, l] = 0.5 * ( np.sqrt((l+m)*(l-m+1)) * pct[m-1, l] - np.sqrt((l-m)*(l+m+1)) * pct[m+1, l] )
dpct[0, m, l] = 0.5 * ( math.sqrt((l+m)*(l-m+1)) * pct[m-1, l] - math.sqrt((l-m)*(l+m+1)) * pct[m+1, l] )
# m == l
if mmax > l:
dpct[0, l, l] = np.sqrt(l/2) * pct[l-1, l]
dpct[0, l, l] = math.sqrt(l/2) * pct[l-1, l]
# fill the -1j m P^m_l / sin(theta) terms. As this component is purely imaginary,
# we won't store it explicitly in a complex array
for m in range(1, min(l+1, mmax)):
# this component is implicitly complex
# we do not divide by m here as this cancels with the derivative of the exponential
dpct[1, m, l] = 0.5 * np.sqrt((2*l+1)/(2*l+3)) * \
( np.sqrt((l-m+1)*(l-m+2)) * pct[m-1, l+1] + np.sqrt((l+m+1)*(l+m+2)) * pct[m+1, l+1] )
dpct[1, m, l] = 0.5 * math.sqrt((2*l+1)/(2*l+3)) * \
( math.sqrt((l-m+1)*(l-m+2)) * pct[m-1, l+1] + math.sqrt((l+m+1)*(l+m+2)) * pct[m+1, l+1] )
if csphase:
for m in range(1, mmax, 2):
dpct[:, m] *= -1
return dpct
\ No newline at end of file
return dpct
......@@ -29,10 +29,14 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import Tuple, Optional
from torch_harmonics.cache import lru_cache
import math
import numpy as np
import torch
def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0,
periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
if (grid != "equidistant") and periodic:
raise ValueError(f"Periodic grid is only supported on equidistant grids.")
......@@ -51,31 +55,41 @@ def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
return xlg, wlg
@lru_cache(typed=True, copy=True)
def _precompute_longitudes(nlon: int) -> torch.Tensor:
r"""
Convenience routine to precompute longitudes
"""
lons = torch.linspace(0, 2 * math.pi, nlon+1, dtype=torch.float64, requires_grad=False)[:-1]
return lons
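A worked example of the endpoint-exclusive convention (2 pi coincides with lon = 0 and is therefore dropped):

lons = _precompute_longitudes(4)
# -> tensor([0.0000, 1.5708, 3.1416, 4.7124], dtype=torch.float64), i.e. [0, pi/2, pi, 3*pi/2]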
def _precompute_latitudes(nlat, grid="equiangular"):
@lru_cache(typed=True, copy=True)
def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Convenience routine to precompute latitudes
"""
# compute coordinates in the cosine theta domain
xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)
# to perform the quadrature and account for the jacobian of the sphere, the quadrature rule
# is formulated in the cosine theta domain, which is designed to integrate functions of cos theta
lats = np.flip(np.arccos(xlg)).copy()
wlg = np.flip(wlg).copy()
lats = torch.flip(torch.arccos(xlg), dims=(0,)).clone()
wlg = torch.flip(wlg, dims=(0,)).clone()
return lats, wlg
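A quick sanity sketch (assumed properties: the returned colatitudes increase from the north pole, and the Legendre-Gauss weights on [-1, 1] sum to 2):

lats, w = _precompute_latitudes(8, grid="legendre-gauss")
assert torch.all(lats[1:] > lats[:-1])
assert torch.allclose(w.sum(), torch.tensor(2.0, dtype=torch.float64))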
def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0, periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Helper routine which returns equidistant nodes with trapezoidal weights
on the interval [a, b]
"""
xlg = np.linspace(a, b, n, endpoint=periodic)
wlg = (b - a) / (n - periodic * 1) * np.ones(n)
xlg = torch.from_numpy(np.linspace(a, b, n, endpoint=periodic))
wlg = (b - a) / (n - periodic * 1) * torch.ones(n, requires_grad=False)
if not periodic:
wlg[0] *= 0.5
......@@ -84,35 +98,38 @@ def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
return xlg, wlg
def legendre_gauss_weights(n, a=-1.0, b=1.0):
def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Helper routine which returns the Legendre-Gauss nodes and weights
on the interval [a, b]
"""
xlg, wlg = np.polynomial.legendre.leggauss(n)
xlg = torch.from_numpy(xlg).clone()
wlg = torch.from_numpy(wlg).clone()
xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5
wlg = wlg * (b - a) * 0.5
return xlg, wlg
def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
tol: Optional[float]=1e-16, maxiter: Optional[int]=100) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
on the interval [a, b]
"""
wlg = np.zeros((n,))
tlg = np.zeros((n,))
tmp = np.zeros((n,))
wlg = torch.zeros((n,), dtype=torch.float64, requires_grad=False)
tlg = torch.zeros((n,), dtype=torch.float64, requires_grad=False)
tmp = torch.zeros((n,), dtype=torch.float64, requires_grad=False)
# Vandermonde Matrix
vdm = np.zeros((n, n))
vdm = torch.zeros((n, n), dtype=torch.float64, requires_grad=False)
# initialize Chebyshev nodes as first guess
for i in range(n):
tlg[i] = -np.cos(np.pi * i / (n - 1))
tlg[i] = -math.cos(math.pi * i / (n - 1))
tmp = 2.0
......@@ -139,7 +156,7 @@ def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
return tlg, wlg
def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Computation of the Clenshaw-Curtis quadrature nodes and weights.
This implementation follows
......@@ -149,26 +166,27 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
assert n > 1
tcc = np.cos(np.linspace(np.pi, 0, n))
tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False))
if n == 2:
wcc = np.array([1.0, 1.0])
wcc = torch.tensor([1.0, 1.0], dtype=torch.float64)
else:
n1 = n - 1
N = np.arange(1, n1, 2)
N = torch.arange(1, n1, 2, dtype=torch.float64)
l = len(N)
m = n1 - l
v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
v = 0 - v[:-1] - v[-1:0:-1]
v = torch.cat([2 / N / (N - 2), 1 / N[-1:], torch.zeros(m, dtype=torch.float64, requires_grad=False)])
#v = 0 - v[:-1] - v[-1:0:-1]
v = 0 - v[:-1] - torch.flip(v[1:], dims=(0,))
g0 = -np.ones(n1)
g0 = -torch.ones(n1, dtype=torch.float64, requires_grad=False)
g0[l] = g0[l] + n1
g0[m] = g0[m] + n1
g = g0 / (n1**2 - 1 + (n1 % 2))
wcc = np.fft.ifft(v + g).real
wcc = np.concatenate((wcc, wcc[:1]))
wcc = torch.fft.ifft(v + g).real
wcc = torch.cat((wcc, wcc[:1]))
# rescale
tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
......@@ -177,7 +195,7 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
return tcc, wcc
def fejer2_weights(n, a=-1.0, b=1.0):
def fejer2_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Computation of the Fejer quadrature nodes and weights.
This implementation follows
......@@ -187,18 +205,19 @@ def fejer2_weights(n, a=-1.0, b=1.0):
assert n > 2
tcc = np.cos(np.linspace(np.pi, 0, n))
tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False))
n1 = n - 1
N = np.arange(1, n1, 2)
N = torch.arange(1, n1, 2, dtype=torch.float64)
l = len(N)
m = n1 - l
v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
v = 0 - v[:-1] - v[-1:0:-1]
v = torch.cat([2 / N / (N - 2), 1 / N[-1:], torch.zeros(m, dtype=torch.float64, requires_grad=False)])
#v = 0 - v[:-1] - v[-1:0:-1]
v = 0 - v[:-1] - torch.flip(v[1:], dims=(0,))
wcc = np.fft.ifft(v).real
wcc = np.concatenate((wcc, wcc[:1]))
wcc = torch.fft.ifft(v).real
wcc = torch.cat((wcc, wcc[:1]))
# rescale
tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
......
......@@ -31,12 +31,12 @@
from typing import List, Tuple, Union, Optional
import math
import numpy as np
#import numpy as np
import torch
import torch.nn as nn
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
class ResampleS2(nn.Module):
......@@ -67,54 +67,53 @@ class ResampleS2(nn.Module):
# for upscaling the latitudes we will use interpolation
self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
self.lons_in = np.linspace(0, 2 * math.pi, nlon_in, endpoint=False)
self.lons_in = _precompute_longitudes(nlon_in)
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = np.linspace(0, 2 * math.pi, nlon_out, endpoint=False)
self.lons_out = _precompute_longitudes(nlon_out)
# in the case where some points lie outside of the range spanned by lats_in,
# we need to expand the solution to the poles before interpolating
self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
if self.expand_poles:
self.lats_in = np.insert(self.lats_in, 0, 0.0)
self.lats_in = np.append(self.lats_in, np.pi)
self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
self.lats_in,
torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
# prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1
lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
# make sure that we properly treat the last point if they coincide with the pole
lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
lat_idx = torch.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
# compute the interpolation weights along the latitude
lat_weights = torch.from_numpy((self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx]).float()
lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
lat_weights = lat_weights.unsqueeze(-1)
# convert to tensor
lat_idx = torch.LongTensor(lat_idx)
# register buffers
self.register_buffer("lat_idx", lat_idx, persistent=False)
self.register_buffer("lat_weights", lat_weights, persistent=False)
# get left and right indices but this time make sure periodicity in the longitude is handled
lon_idx_left = np.searchsorted(self.lons_in, self.lons_out, side="right") - 1
lon_idx_right = np.where(self.lons_out >= self.lons_in[-1], np.zeros_like(lon_idx_left), lon_idx_left + 1)
lon_idx_left = torch.searchsorted(self.lons_in, self.lons_out, side="right") - 1
lon_idx_right = torch.where(self.lons_out >= self.lons_in[-1], torch.zeros_like(lon_idx_left), lon_idx_left + 1)
# get the difference
diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
diff = np.where(diff < 0.0, diff + 2 * math.pi, diff)
lon_weights = torch.from_numpy((self.lons_out - self.lons_in[lon_idx_left]) / diff).float()
# convert to tensor
lon_idx_left = torch.LongTensor(lon_idx_left)
lon_idx_right = torch.LongTensor(lon_idx_right)
diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)
lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)
# register buffers
self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
self.register_buffer("lon_weights", lon_weights, persistent=False)
self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)
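The `lat_idx`/`lat_weights` and `lon_idx_*`/`lon_weights` buffers built above implement ordinary 1-D linear interpolation with precomputed indices and weights (plus a periodic wrap in longitude). The same pattern in isolation, with purely illustrative names:

```python
import torch

# linear interpolation y_out = (1 - w) * y[idx] + w * y[idx + 1],
# with idx and w computed once from the input/output coordinates
x_in = torch.linspace(0.0, 1.0, 5, dtype=torch.float64)
x_out = torch.tensor([0.1, 0.4, 0.9], dtype=torch.float64)

idx = torch.searchsorted(x_in, x_out, side="right") - 1
idx = torch.clamp(idx, 0, x_in.shape[0] - 2)
w = (x_out - x_in[idx]) / (x_in[idx + 1] - x_in[idx])

y_in = x_in**2
y_out = torch.lerp(y_in[idx], y_in[idx + 1], w)  # piecewise-linear estimate of x^2
```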
def extra_repr(self):
r"""
Pretty print module
@@ -139,7 +138,7 @@ class ResampleS2(nn.Module):
repeats[-1] = x.shape[-1]
x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
x = torch.concatenate((x_north, x, x_south), dim=-2)
x = torch.concatenate((x_north, x, x_south), dim=-2).contiguous()
return x
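`_expand_poles` pads the latitude axis with one ring per pole holding the zonal mean of the nearest ring, so that output latitudes beyond the input range can still be interpolated. A toy analogue of the shape logic (illustrative only, using `expand` in place of `repeat`):

```python
import torch

x = torch.randn(1, 1, 4, 8)  # [batch, channel, nlat, nlon]
x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).expand(-1, -1, -1, x.shape[-1])
x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).expand(-1, -1, -1, x.shape[-1])
x_pad = torch.cat((x_north, x, x_south), dim=-2)
assert x_pad.shape == (1, 1, 6, 8)
# the added rings are constant in longitude
assert torch.allclose(x_pad[..., 0, :], x_pad[..., 0, :1])
```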
def _upscale_latitudes(self, x: torch.Tensor):
@@ -156,6 +155,9 @@ class ResampleS2(nn.Module):
return x
def forward(self, x: torch.Tensor):
if self.skip_resampling:
return x
if self.expand_poles:
x = self._expand_poles(x)
x = self._upscale_latitudes(x)
......
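End to end, the module is used like any other `nn.Module`. A usage sketch; the package-level export and the constructor argument order are assumed from the attribute names above, not confirmed by this diff:

```python
import torch
from torch_harmonics import ResampleS2  # export path assumed

# upsample from a 32 x 64 to a 64 x 128 equiangular grid
up = ResampleS2(32, 64, 64, 128, grid_in="equiangular", grid_out="equiangular")
x = torch.randn(2, 3, 32, 64)
y = up(x)
print(y.shape)  # expected: torch.Size([2, 3, 64, 128])
```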
@@ -29,7 +29,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
import torch
import torch.nn as nn
import torch.fft
@@ -70,17 +69,17 @@ class RealSHT(nn.Module):
# compute quadrature points and lmax based on the exactness of the quadrature
if self.grid == "legendre-gauss":
cost, w = legendre_gauss_weights(nlat, -1, 1)
cost, weights = legendre_gauss_weights(nlat, -1, 1)
# maximum polynomial degree for Gauss Legendre is 2 * nlat - 1 >= 2 * lmax
# and therefore lmax = nlat - 1 (inclusive)
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1)
cost, weights = lobatto_weights(nlat, -1, 1)
# maximum polynomial degree for Gauss-Lobatto is 2 * nlat - 3 >= 2 * lmax
# and therefore lmax = nlat - 2 (inclusive)
self.lmax = lmax or self.nlat - 1
elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
# in principle, Clenshaw-Curtis quadrature is only exact up to polynomial degree nlat
# however, we observe that the quadrature is remarkably accurate for higher degrees. This is why we do not
# choose a lower lmax for now.
@@ -89,16 +88,14 @@ class RealSHT(nn.Module):
raise (ValueError("Unknown quadrature mode"))
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
tq = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum("mlk,k->mlk", pct, weights)
weights = torch.einsum("mlk,k->mlk", pct, weights).contiguous()
# remember quadrature weights
self.register_buffer("weights", weights, persistent=False)
@@ -172,13 +169,12 @@ class InverseRealSHT(nn.Module):
raise (ValueError("Unknown quadrature mode"))
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
t = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)
# register buffer
self.register_buffer("pct", pct, persistent=False)
@@ -241,32 +237,29 @@ class RealVectorSHT(nn.Module):
# compute quadrature points
if self.grid == "legendre-gauss":
cost, w = legendre_gauss_weights(nlat, -1, 1)
cost, weights = legendre_gauss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1)
cost, weights = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat - 1
elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
raise (ValueError("Unknown quadrature mode"))
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
tq = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
weights = torch.from_numpy(w)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax)
norm_factor = 1.0 / l / (l + 1)
norm_factor[0] = 1.0
weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor)
weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor).contiguous()
# since the second component is imaginary, we need to take complex conjugation into account
weights[1] = -1 * weights[1]
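The factor 1/(l(l+1)) is singular at l = 0, which the code above handles by overwriting the first entry after the division. An equivalent one-step construction (a sketch; the division still produces an inf at l = 0, but `torch.where` discards it):

```python
import torch

lmax = 8
l = torch.arange(0, lmax, dtype=torch.float64)
norm_factor = torch.where(l > 0, 1.0 / (l * (l + 1.0)), torch.ones_like(l))
```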
@@ -356,13 +349,12 @@ class InverseRealVectorSHT(nn.Module):
raise (ValueError("Unknown quadrature mode"))
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
t = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# register weights
self.register_buffer("dpct", dpct, persistent=False)
......
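Analogous usage for the vector transforms, which act on inputs of shape (..., 2, nlat, nlon) with the two vector components stacked in front of the spatial axes. A hedged sketch, assuming both classes are exported at package level:

```python
import torch
from torch_harmonics import RealVectorSHT, InverseRealVectorSHT  # export path assumed

nlat, nlon = 16, 32
vsht = RealVectorSHT(nlat, nlon, grid="legendre-gauss")
ivsht = InverseRealVectorSHT(nlat, nlon, grid="legendre-gauss")

uv = torch.randn(1, 2, nlat, nlon, dtype=torch.float64)  # theta/phi components
coeffs = vsht(uv)
uv2 = ivsht(coeffs)
print(uv2.shape)  # expected: torch.Size([1, 2, 16, 32])
```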