Unverified Commit 6730e5c1 authored by Thorsten Kurth, committed by GitHub

Tkurth/torchification (#66)

* adding caching

* replacing many numpy calls with torch calls

* bumping up version number to 0.7.6
parent 780fd143
...@@ -2,7 +2,16 @@ ...@@ -2,7 +2,16 @@
## Versioning ## Versioning
### v0.7.6
* Adding a cache for precomputed tensors such as the weight tensors for DISCO and SHT
* The cache returns copies of tensors rather than references. Users are still encouraged to re-use
  those tensors manually in their models, since doing so also saves memory; the cache mainly
  helps with model setup speed.
* Adding a test which ensures that the cache works correctly
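As a brief illustration of the copy semantics described above, here is a minimal sketch built on the `lru_cache` decorator that ships in the new `torch_harmonics.cache` module (the cached function below is a placeholder, not part of the library):

```python
import torch
from torch_harmonics.cache import lru_cache

@lru_cache(maxsize=20, typed=True, copy=True)
def expensive_weights(n: int) -> torch.Tensor:
    # stand-in for a precomputed weight tensor (e.g. DISCO or SHT weights)
    return torch.arange(n, dtype=torch.float64)

w1 = expensive_weights(4)
w1 *= -1.0                 # in-place edit of the returned copy...
w2 = expensive_weights(4)  # ...does not affect subsequent cache hits
assert not torch.allclose(w1, w2)
```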
### v0.7.5 ### v0.7.5
* New normalization mode `support` for DISCO convolutions * New normalization mode `support` for DISCO convolutions
* More efficient computation of Morlet filter basis * More efficient computation of Morlet filter basis
* Changed default for Morlet filter basis to a Hann window function * Changed default for Morlet filter basis to a Hann window function
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import unittest
from parameterized import parameterized
import math
import torch
class TestCacheConsistency(unittest.TestCase):
def test_consistency(self, verbose=False):
if verbose:
print("Testing that cache values does not get modified externally")
from torch_harmonics.legendre import _precompute_legpoly
with torch.no_grad():
cost = torch.cos(torch.linspace(0.0, 2.0 * math.pi, 10, dtype=torch.float64))
leg1 = _precompute_legpoly(10, 10, cost)
# perform in-place modification of leg1
leg1 *= -1.0
leg2 = _precompute_legpoly(10, 10, cost)
self.assertFalse(torch.allclose(leg1, leg2))
if __name__ == "__main__":
unittest.main()
...@@ -38,7 +38,7 @@ import torch ...@@ -38,7 +38,7 @@ import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_harmonics import * from torch_harmonics import *
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9): def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
...@@ -106,22 +106,20 @@ def _precompute_convolution_tensor_dense( ...@@ -106,22 +106,20 @@ def _precompute_convolution_tensor_dense(
nlat_out, nlon_out = out_shape nlat_out, nlon_out = out_shape
lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in) lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in)
lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out) lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out)
# compute the phi differences. We need to make the linspace exclusive to not double the last point # compute the phi differences.
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1] lons_in = _precompute_longitudes(nlon_in)
lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1, dtype=torch.float64)[:-1] lons_out = _precompute_longitudes(nlon_out)
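Judging from the `torch.linspace` call it replaces, `_precompute_longitudes(nlon)` presumably returns `nlon` equispaced longitudes on [0, 2π) with the endpoint excluded, so that lon = 2π does not duplicate lon = 0. A minimal equivalent sketch (hypothetical re-implementation, not the library's source):

```python
import math
import torch

def precompute_longitudes(nlon: int) -> torch.Tensor:
    # nlon equispaced longitudes on [0, 2*pi); the endpoint is dropped
    # because lon = 2*pi coincides with lon = 0 on the sphere
    return torch.linspace(0, 2 * math.pi, nlon + 1, dtype=torch.float64)[:-1]
```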
# the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles) # the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
# compute quadrature weights that will be merged into the Psi tensor # compute quadrature weights that will be merged into the Psi tensor
if transpose_normalization: if transpose_normalization:
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0 quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
else: else:
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0 quad_weights = win.reshape(-1, 1) / nlon_in / 2.0
# array for accumulating non-zero indices # array for accumulating non-zero indices
out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64) out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
...@@ -172,34 +170,32 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -172,34 +170,32 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.device = torch.device("cpu")
@parameterized.expand( @parameterized.expand(
[ [
# regular convolution # regular convolution
[8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (16, 32), (8, 16), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (24, 48), (12, 24), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (24, 48), (12, 24), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (24, 48), (12, 24), [2, 2], "morlet", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (24, 48), (12, 24), [3], "zernike", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (16, 24), (8, 8), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (18, 36), (6, 12), [7], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4], [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4], [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4, False],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4], [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4, False],
[8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4], [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
# transpose convolution # transpose convolution
[8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (12, 24), (24, 48), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (12, 24), (24, 48), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (12, 24), (24, 48), [2, 2], "morlet", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (12, 24), (24, 48), [3], "zernike", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (8, 8), (16, 24), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (6, 12), (18, 36), [7], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4], [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4], [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, False],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4], [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, False],
[8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4], [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
] ]
) )
def test_disco_convolution( def test_disco_convolution(
...@@ -216,11 +212,19 @@ class TestDiscreteContinuousConvolution(unittest.TestCase): ...@@ -216,11 +212,19 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
grid_out, grid_out,
transpose, transpose,
tol, tol,
verbose,
): ):
if verbose:
print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device")
nlat_in, nlon_in = in_shape nlat_in, nlon_in = in_shape
nlat_out, nlon_out = out_shape nlat_out, nlon_out = out_shape
theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) if isinstance(kernel_shape, int):
theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1)
else:
theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
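For example, the first parameter set above (`kernel_shape = (3)`, which is just the int 3, and `in_shape = (16, 32)`) lands in the first branch and gives `theta_cutoff = 4 * pi / 15 ≈ 0.84` rad.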
Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
conv = Conv( conv = Conv(
......
...@@ -183,18 +183,18 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase): ...@@ -183,18 +183,18 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
[ [
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[129, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [129, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 64, 128, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 64, 128, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5], [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", False, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[129, 256, 129, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [129, 256, 129, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3, 2], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, (3, 2), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[64, 128, 128, 256, 32, 8, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 8, [3], "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
[128, 256, 128, 256, 32, 6, [3], "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5], [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
] ]
) )
def test_distributed_disco_conv( def test_distributed_disco_conv(
......
...@@ -183,14 +183,14 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -183,14 +183,14 @@ class TestDistributedResampling(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
[ [
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7], [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7], [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7], [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7], [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
] ]
) )
def test_distributed_resampling( def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose
): ):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
...@@ -248,7 +248,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -248,7 +248,7 @@ class TestDistributedResampling(unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist) out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0: if verbose and (self.world_rank == 0):
print(f"final relative error of output: {err.item()}") print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
...@@ -259,7 +259,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -259,7 +259,7 @@ class TestDistributedResampling(unittest.TestCase):
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist) igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0: if verbose and (self.world_rank == 0):
print(f"final relative error of gradients: {err.item()}") print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
......
...@@ -32,7 +32,6 @@ ...@@ -32,7 +32,6 @@
import unittest import unittest
from parameterized import parameterized from parameterized import parameterized
import math import math
import numpy as np
import torch import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_harmonics import * from torch_harmonics import *
...@@ -41,31 +40,32 @@ from torch_harmonics import * ...@@ -41,31 +40,32 @@ from torch_harmonics import *
class TestLegendrePolynomials(unittest.TestCase): class TestLegendrePolynomials(unittest.TestCase):
def setUp(self): def setUp(self):
self.cml = lambda m, l: np.sqrt((2 * l + 1) / 4 / np.pi) * np.sqrt(math.factorial(l - m) / math.factorial(l + m)) self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
self.pml = dict() self.pml = dict()
# preparing associated Legendre Polynomials (These include the Condon-Shortley phase) # preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
# for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials # for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
self.pml[(0, 0)] = lambda x: np.ones_like(x) self.pml[(0, 0)] = lambda x: torch.ones_like(x)
self.pml[(0, 1)] = lambda x: x self.pml[(0, 1)] = lambda x: x
self.pml[(1, 1)] = lambda x: -np.sqrt(1.0 - x**2) self.pml[(1, 1)] = lambda x: -torch.sqrt(1.0 - x**2)
self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1) self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1)
self.pml[(1, 2)] = lambda x: -3 * x * np.sqrt(1.0 - x**2) self.pml[(1, 2)] = lambda x: -3 * x * torch.sqrt(1.0 - x**2)
self.pml[(2, 2)] = lambda x: 3 * (1 - x**2) self.pml[(2, 2)] = lambda x: 3 * (1 - x**2)
self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x) self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x)
self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * np.sqrt(1.0 - x**2) self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * torch.sqrt(1.0 - x**2)
self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2) self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2)
self.pml[(3, 3)] = lambda x: -15 * np.sqrt(1.0 - x**2) ** 3 self.pml[(3, 3)] = lambda x: -15 * torch.sqrt(1.0 - x**2) ** 3
self.lmax = self.mmax = 4 self.lmax = self.mmax = 4
self.tol = 1e-9 self.tol = 1e-9
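The lambda `self.cml` above is the standard orthonormalization constant for spherical harmonics,

$$C_l^m = \sqrt{\frac{2l+1}{4\pi}}\sqrt{\frac{(l-m)!}{(l+m)!}},$$

so that $Y_l^m(\theta,\varphi) = C_l^m\,P_l^m(\cos\theta)\,e^{im\varphi}$ is orthonormal on the sphere, with the Condon-Shortley phase carried by the $P_l^m$ tabulated above.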
def test_legendre(self): def test_legendre(self, verbose=False):
print("Testing computation of associated Legendre polynomials") if verbose:
print("Testing computation of associated Legendre polynomials")
from torch_harmonics.legendre import legpoly from torch_harmonics.legendre import legpoly
t = np.linspace(0, 1, 100) t = torch.linspace(0, 1, 100, dtype=torch.float64)
vdm = legpoly(self.mmax, self.lmax, t) vdm = legpoly(self.mmax, self.lmax, t)
for l in range(self.lmax): for l in range(self.lmax):
...@@ -79,24 +79,23 @@ class TestSphericalHarmonicTransform(unittest.TestCase): ...@@ -79,24 +79,23 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
def setUp(self): def setUp(self):
if torch.cuda.is_available(): if torch.cuda.is_available():
print("Running test on GPU")
self.device = torch.device("cuda") self.device = torch.device("cuda")
else: else:
print("Running test on CPU")
self.device = torch.device("cpu") self.device = torch.device("cpu")
@parameterized.expand( @parameterized.expand(
[ [
[256, 512, 32, "ortho", "equiangular", 1e-9], [256, 512, 32, "ortho", "equiangular", 1e-9, False],
[256, 512, 32, "ortho", "legendre-gauss", 1e-9], [256, 512, 32, "ortho", "legendre-gauss", 1e-9, False],
[256, 512, 32, "four-pi", "equiangular", 1e-9], [256, 512, 32, "four-pi", "equiangular", 1e-9, False],
[256, 512, 32, "four-pi", "legendre-gauss", 1e-9], [256, 512, 32, "four-pi", "legendre-gauss", 1e-9, False],
[256, 512, 32, "schmidt", "equiangular", 1e-9], [256, 512, 32, "schmidt", "equiangular", 1e-9, False],
[256, 512, 32, "schmidt", "legendre-gauss", 1e-9], [256, 512, 32, "schmidt", "legendre-gauss", 1e-9, False],
] ]
) )
def test_sht(self, nlat, nlon, batch_size, norm, grid, tol): def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization") if verbose:
print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")
testiters = [1, 2, 4, 8, 16] testiters = [1, 2, 4, 8, 16]
if grid == "equiangular": if grid == "equiangular":
...@@ -116,7 +115,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase): ...@@ -116,7 +115,8 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
# testing error accumulation # testing error accumulation
for iter in testiters: for iter in testiters:
with self.subTest(i=iter): with self.subTest(i=iter):
print(f"{iter} iterations of batchsize {batch_size}:") if verbose:
print(f"{iter} iterations of batchsize {batch_size}:")
base = signal base = signal
...@@ -124,27 +124,29 @@ class TestSphericalHarmonicTransform(unittest.TestCase): ...@@ -124,27 +124,29 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
base = isht(sht(base)) base = isht(sht(base))
err = torch.mean(torch.norm(base - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2))) err = torch.mean(torch.norm(base - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))
print(f"final relative error: {err.item()}") if verbose:
print(f"final relative error: {err.item()}")
self.assertTrue(err.item() <= tol) self.assertTrue(err.item() <= tol)
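In other words, the quantity being bounded is the batch-averaged relative Frobenius error after $k$ round trips through the transform pair:

$$\mathrm{err} = \frac{1}{B}\sum_{b=1}^{B}\frac{\left\lVert\left(\mathrm{ISHT}\circ\mathrm{SHT}\right)^{k}x_b - x_b\right\rVert_F}{\lVert x_b\rVert_F} \le \mathrm{tol}.$$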
@parameterized.expand( @parameterized.expand(
[ [
[12, 24, 2, "ortho", "equiangular", 1e-5], [12, 24, 2, "ortho", "equiangular", 1e-5, False],
[12, 24, 2, "ortho", "legendre-gauss", 1e-5], [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
[12, 24, 2, "four-pi", "equiangular", 1e-5], [12, 24, 2, "four-pi", "equiangular", 1e-5, False],
[12, 24, 2, "four-pi", "legendre-gauss", 1e-5], [12, 24, 2, "four-pi", "legendre-gauss", 1e-5, False],
[12, 24, 2, "schmidt", "equiangular", 1e-5], [12, 24, 2, "schmidt", "equiangular", 1e-5, False],
[12, 24, 2, "schmidt", "legendre-gauss", 1e-5], [12, 24, 2, "schmidt", "legendre-gauss", 1e-5, False],
[15, 30, 2, "ortho", "equiangular", 1e-5], [15, 30, 2, "ortho", "equiangular", 1e-5, False],
[15, 30, 2, "ortho", "legendre-gauss", 1e-5], [15, 30, 2, "ortho", "legendre-gauss", 1e-5, False],
[15, 30, 2, "four-pi", "equiangular", 1e-5], [15, 30, 2, "four-pi", "equiangular", 1e-5, False],
[15, 30, 2, "four-pi", "legendre-gauss", 1e-5], [15, 30, 2, "four-pi", "legendre-gauss", 1e-5, False],
[15, 30, 2, "schmidt", "equiangular", 1e-5], [15, 30, 2, "schmidt", "equiangular", 1e-5, False],
[15, 30, 2, "schmidt", "legendre-gauss", 1e-5], [15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
] ]
) )
def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol): def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization") if verbose:
print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")
if grid == "equiangular": if grid == "equiangular":
mmax = nlat // 2 mmax = nlat // 2
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# #
__version__ = "0.7.5a" __version__ = "0.7.6"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import functools
from copy import deepcopy
# copying LRU cache decorator a la:
# https://stackoverflow.com/questions/54909357/how-to-get-functools-lru-cache-to-return-new-instances
def lru_cache(maxsize=20, typed=False, copy=False):
def decorator(f):
cached_func = functools.lru_cache(maxsize=maxsize, typed=typed)(f)
def wrapper(*args, **kwargs):
res = cached_func(*args, **kwargs)
if copy:
return deepcopy(res)
else:
return res
return wrapper
return decorator
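A note on why the `copy` flag exists: plain `functools.lru_cache` hands every caller the same object, so an in-place modification silently corrupts all later cache hits. A minimal demonstration of the failure mode that `copy=True` avoids (plain-Python placeholder function):

```python
import functools

@functools.lru_cache(maxsize=None)
def shared():
    return [1, 2, 3]

x = shared()
x[0] = 99                  # mutates the object stored in the cache
assert shared()[0] == 99   # every subsequent caller sees the corruption
```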
...@@ -40,10 +40,11 @@ import torch.nn as nn ...@@ -40,10 +40,11 @@ import torch.nn as nn
from functools import partial from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis from torch_harmonics.filter_basis import FilterBasis, get_filter_basis
# import custom C++/CUDA extensions if available # import custom C++/CUDA extensions if available
try: try:
...@@ -134,17 +135,18 @@ def _normalize_convolution_tensor_s2( ...@@ -134,17 +135,18 @@ def _normalize_convolution_tensor_s2(
return psi_vals return psi_vals
@lru_cache(typed=True, copy=True)
def _precompute_convolution_tensor_s2( def _precompute_convolution_tensor_s2(
in_shape, in_shape: Tuple[int],
out_shape, out_shape: Tuple[int],
filter_basis, filter_basis: FilterBasis,
grid_in="equiangular", grid_in: Optional[str] = "equiangular",
grid_out="equiangular", grid_out: Optional[str] = "equiangular",
theta_cutoff=0.01 * math.pi, theta_cutoff: Optional[float] = 0.01 * math.pi,
theta_eps = 1e-3, theta_eps: Optional[float] = 1e-3,
transpose_normalization=False, transpose_normalization: Optional[bool] = False,
basis_norm_mode="mean", basis_norm_mode: Optional[str] = "mean",
merge_quadrature=False, merge_quadrature: Optional[bool] = False,
): ):
""" """
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$. Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
...@@ -172,20 +174,18 @@ def _precompute_convolution_tensor_s2( ...@@ -172,20 +174,18 @@ def _precompute_convolution_tensor_s2(
# precompute input and output grids # precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in) lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out) lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out)
# compute the phi differences # compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0 # It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1] lons_in = _precompute_longitudes(nlon_in)
# compute quadrature weights and merge them into the convolution tensor. # compute quadrature weights and merge them into the convolution tensor.
# These quadrature weights integrate to 1 over the sphere. # These quadrature weights integrate to 1 over the sphere.
if transpose_normalization: if transpose_normalization:
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0 quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
else: else:
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0 quad_weights = win.reshape(-1, 1) / nlon_in / 2.0
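Since the latitudinal quadrature weights $w_j$ of the supported rules sum to $\int_{-1}^{1} d(\cos\theta) = 2$, dividing by $2\,n_{lon}$ makes the merged weights sum to one over the full grid, consistent with the comment above:

$$\sum_{j=1}^{n_{lat}}\sum_{k=1}^{n_{lon}}\frac{w_j}{2\,n_{lon}} = \frac{1}{2}\sum_{j=1}^{n_{lat}} w_j = 1.$$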
# the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles) # the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
...@@ -258,7 +258,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta): ...@@ -258,7 +258,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear", basis_type: Optional[str] = "piecewise linear",
groups: Optional[int] = 1, groups: Optional[int] = 1,
bias: Optional[bool] = True, bias: Optional[bool] = True,
...@@ -309,7 +309,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv): ...@@ -309,7 +309,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
out_channels: int, out_channels: int,
in_shape: Tuple[int], in_shape: Tuple[int],
out_shape: Tuple[int], out_shape: Tuple[int],
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear", basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean", basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1, groups: Optional[int] = 1,
...@@ -415,7 +415,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -415,7 +415,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_channels: int, out_channels: int,
in_shape: Tuple[int], in_shape: Tuple[int],
out_shape: Tuple[int], out_shape: Tuple[int],
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear", basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean", basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1, groups: Optional[int] = 1,
......
...@@ -41,7 +41,7 @@ import torch.nn as nn ...@@ -41,7 +41,7 @@ import torch.nn as nn
from functools import partial from functools import partial
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis from torch_harmonics.filter_basis import get_filter_basis
...@@ -106,20 +106,18 @@ def _precompute_distributed_convolution_tensor_s2( ...@@ -106,20 +106,18 @@ def _precompute_distributed_convolution_tensor_s2(
# precompute input and output grids # precompute input and output grids
lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in) lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
lats_in = torch.from_numpy(lats_in)
lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out) lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
lats_out = torch.from_numpy(lats_out)
# compute the phi differences # compute the phi differences
# It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0 # It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1] lons_in = _precompute_longitudes(nlon_in)
# compute quadrature weights and merge them into the convolution tensor. # compute quadrature weights and merge them into the convolution tensor.
# These quadrature weights integrate to 1 over the sphere. # These quadrature weights integrate to 1 over the sphere.
if transpose_normalization: if transpose_normalization:
quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0 quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
else: else:
quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0 quad_weights = win.reshape(-1, 1) / nlon_in / 2.0
# the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles) # the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles)
theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff
...@@ -215,7 +213,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv): ...@@ -215,7 +213,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
out_channels: int, out_channels: int,
in_shape: Tuple[int], in_shape: Tuple[int],
out_shape: Tuple[int], out_shape: Tuple[int],
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear", basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean", basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1, groups: Optional[int] = 1,
...@@ -356,7 +354,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -356,7 +354,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_channels: int, out_channels: int,
in_shape: Tuple[int], in_shape: Tuple[int],
out_shape: Tuple[int], out_shape: Tuple[int],
kernel_shape: Union[int, List[int]], kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
basis_type: Optional[str] = "piecewise linear", basis_type: Optional[str] = "piecewise linear",
basis_norm_mode: Optional[str] = "mean", basis_norm_mode: Optional[str] = "mean",
groups: Optional[int] = 1, groups: Optional[int] = 1,
......
...@@ -31,12 +31,11 @@ ...@@ -31,12 +31,11 @@
from typing import List, Tuple, Union, Optional from typing import List, Tuple, Union, Optional
import math import math
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch_harmonics.quadrature import _precompute_latitudes from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes from torch_harmonics.distributed import compute_split_shapes
...@@ -82,54 +81,52 @@ class DistributedResampleS2(nn.Module): ...@@ -82,54 +81,52 @@ class DistributedResampleS2(nn.Module):
# for upscaling the latitudes we will use interpolation # for upscaling the latitudes we will use interpolation
self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in) self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
self.lons_in = np.linspace(0, 2 * math.pi, nlon_in, endpoint=False) self.lons_in = _precompute_longitudes(nlon_in)
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out) self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = np.linspace(0, 2 * math.pi, nlon_out, endpoint=False) self.lons_out = _precompute_longitudes(nlon_out)
# in the case where some points lie outside of the range spanned by lats_in, # in the case where some points lie outside of the range spanned by lats_in,
# we need to expand the solution to the poles before interpolating # we need to expand the solution to the poles before interpolating
self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any() self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
if self.expand_poles: if self.expand_poles:
self.lats_in = np.insert(self.lats_in, 0, 0.0) self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
self.lats_in = np.append(self.lats_in, np.pi) self.lats_in,
torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
# prepare the interpolation by computing indices to the left and right of each output latitude # prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1 lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
# make sure that we properly treat the last point if they coincide with the pole # make sure that we properly treat the last point if they coincide with the pole
lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx) lat_idx = torch.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
# lat_idx = np.where(self.lats_out > self.lats_in[-1], lat_idx - 1, lat_idx) # lat_idx = np.where(self.lats_out > self.lats_in[-1], lat_idx - 1, lat_idx)
# lat_idx = np.where(self.lats_out < self.lats_in[0], 0, lat_idx) # lat_idx = np.where(self.lats_out < self.lats_in[0], 0, lat_idx)
# compute the interpolation weights along the latitude # compute the interpolation weights along the latitude
lat_weights = torch.from_numpy((self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx]).float() lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
lat_weights = lat_weights.unsqueeze(-1) lat_weights = lat_weights.unsqueeze(-1)
# convert to tensor
lat_idx = torch.LongTensor(lat_idx)
# register buffers # register buffers
self.register_buffer("lat_idx", lat_idx, persistent=False) self.register_buffer("lat_idx", lat_idx, persistent=False)
self.register_buffer("lat_weights", lat_weights, persistent=False) self.register_buffer("lat_weights", lat_weights, persistent=False)
# get left and right indices but this time make sure periodicity in the longitude is handled # get left and right indices but this time make sure periodicity in the longitude is handled
lon_idx_left = np.searchsorted(self.lons_in, self.lons_out, side="right") - 1 lon_idx_left = torch.searchsorted(self.lons_in, self.lons_out, side="right") - 1
lon_idx_right = np.where(self.lons_out >= self.lons_in[-1], np.zeros_like(lon_idx_left), lon_idx_left + 1) lon_idx_right = torch.where(self.lons_out >= self.lons_in[-1], torch.zeros_like(lon_idx_left), lon_idx_left + 1)
# get the difference # get the difference
diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left] diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
diff = np.where(diff < 0.0, diff + 2 * math.pi, diff) diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)
lon_weights = torch.from_numpy((self.lons_out - self.lons_in[lon_idx_left]) / diff).float() lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)
# convert to tensor
lon_idx_left = torch.LongTensor(lon_idx_left)
lon_idx_right = torch.LongTensor(lon_idx_right)
# register buffers # register buffers
self.register_buffer("lon_idx_left", lon_idx_left, persistent=False) self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
self.register_buffer("lon_idx_right", lon_idx_right, persistent=False) self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
self.register_buffer("lon_weights", lon_weights, persistent=False) self.register_buffer("lon_weights", lon_weights, persistent=False)
self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)
def extra_repr(self): def extra_repr(self):
r""" r"""
Pretty print module Pretty print module
...@@ -172,6 +169,9 @@ class DistributedResampleS2(nn.Module): ...@@ -172,6 +169,9 @@ class DistributedResampleS2(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
if self.skip_resampling:
return x
# transpose data so that h is local, and channels are split # transpose data so that h is local, and channels are split
num_chans = x.shape[-3] num_chans = x.shape[-3]
......
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
# #
import os import os
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.fft import torch.fft
...@@ -75,13 +74,13 @@ class DistributedRealSHT(nn.Module): ...@@ -75,13 +74,13 @@ class DistributedRealSHT(nn.Module):
# compute quadrature points # compute quadrature points
if self.grid == "legendre-gauss": if self.grid == "legendre-gauss":
cost, w = legendre_gauss_weights(nlat, -1, 1) cost, weights = legendre_gauss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat self.lmax = lmax or self.nlat
elif self.grid == "lobatto": elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1) cost, weights = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1 self.lmax = lmax or self.nlat-1
elif self.grid == "equiangular": elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1) cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
# cost, w = fejer2_weights(nlat, -1, 1) # cost, w = fejer2_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat self.lmax = lmax or self.nlat
else: else:
...@@ -94,7 +93,7 @@ class DistributedRealSHT(nn.Module): ...@@ -94,7 +93,7 @@ class DistributedRealSHT(nn.Module):
self.comm_rank_azimuth = azimuth_group_rank() self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them # apply cosine transform and flip them
tq = np.flip(np.arccos(cost)) tq = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions # determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1 self.mmax = mmax or self.nlon // 2 + 1
...@@ -106,13 +105,11 @@ class DistributedRealSHT(nn.Module): ...@@ -106,13 +105,11 @@ class DistributedRealSHT(nn.Module):
self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth) self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)
# combine quadrature weights with the legendre weights # combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights) weights = torch.einsum('mlk,k->mlk', pct, weights)
# split weights # split weights
weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth] weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()
# remember quadrature weights # remember quadrature weights
self.register_buffer('weights', weights, persistent=False) self.register_buffer('weights', weights, persistent=False)
...@@ -208,7 +205,7 @@ class DistributedInverseRealSHT(nn.Module): ...@@ -208,7 +205,7 @@ class DistributedInverseRealSHT(nn.Module):
self.comm_rank_azimuth = azimuth_group_rank() self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them # apply cosine transform and flip them
t = np.flip(np.arccos(cost)) t = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions # determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1 self.mmax = mmax or self.nlon // 2 + 1
...@@ -221,10 +218,9 @@ class DistributedInverseRealSHT(nn.Module): ...@@ -221,10 +218,9 @@ class DistributedInverseRealSHT(nn.Module):
# compute Legendre polynomials # compute Legendre polynomials
pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)
# split in m # split in m
pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth] pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()
# register # register
self.register_buffer('pct', pct, persistent=False) self.register_buffer('pct', pct, persistent=False)
...@@ -308,13 +304,13 @@ class DistributedRealVectorSHT(nn.Module): ...@@ -308,13 +304,13 @@ class DistributedRealVectorSHT(nn.Module):
# compute quadrature points # compute quadrature points
if self.grid == "legendre-gauss": if self.grid == "legendre-gauss":
cost, w = legendre_gauss_weights(nlat, -1, 1) cost, weights = legendre_gauss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat self.lmax = lmax or self.nlat
elif self.grid == "lobatto": elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1) cost, weights = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1 self.lmax = lmax or self.nlat-1
elif self.grid == "equiangular": elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1) cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
# cost, w = fejer2_weights(nlat, -1, 1) # cost, w = fejer2_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat self.lmax = lmax or self.nlat
else: else:
...@@ -327,7 +323,7 @@ class DistributedRealVectorSHT(nn.Module): ...@@ -327,7 +323,7 @@ class DistributedRealVectorSHT(nn.Module):
self.comm_rank_azimuth = azimuth_group_rank() self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them # apply cosine transform and flip them
tq = np.flip(np.arccos(cost)) tq = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions # determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1 self.mmax = mmax or self.nlon // 2 + 1
...@@ -339,9 +335,7 @@ class DistributedRealVectorSHT(nn.Module): ...@@ -339,9 +335,7 @@ class DistributedRealVectorSHT(nn.Module):
self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth) self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)
# compute weights # compute weights
weights = torch.from_numpy(w)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# combine integration weights, normalization factor in to one: # combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax) l = torch.arange(0, self.lmax)
...@@ -352,7 +346,7 @@ class DistributedRealVectorSHT(nn.Module): ...@@ -352,7 +346,7 @@ class DistributedRealVectorSHT(nn.Module):
weights[1] = -1 * weights[1] weights[1] = -1 * weights[1]
# we need to split in m, pad before: # we need to split in m, pad before:
weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth] weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()
# remember quadrature weights # remember quadrature weights
self.register_buffer('weights', weights, persistent=False) self.register_buffer('weights', weights, persistent=False)
...@@ -461,7 +455,7 @@ class DistributedInverseRealVectorSHT(nn.Module): ...@@ -461,7 +455,7 @@ class DistributedInverseRealVectorSHT(nn.Module):
self.comm_rank_azimuth = azimuth_group_rank() self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them # apply cosine transform and flip them
t = np.flip(np.arccos(cost)) t = torch.flip(torch.arccos(cost), dims=(0,))
# determine the dimensions # determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1 self.mmax = mmax or self.nlon // 2 + 1
...@@ -474,10 +468,9 @@ class DistributedInverseRealVectorSHT(nn.Module): ...@@ -474,10 +468,9 @@ class DistributedInverseRealVectorSHT(nn.Module):
# compute Legendre polynomials # compute Legendre polynomials
dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# split in m # split in m
dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth] dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()
# register buffer # register buffer
self.register_buffer('dpct', dpct, persistent=False) self.register_buffer('dpct', dpct, persistent=False)
......
...@@ -33,7 +33,9 @@ ...@@ -33,7 +33,9 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch_harmonics as harmonics import torch_harmonics as harmonics
from torch_harmonics.quadrature import _precompute_longitudes
import math
import numpy as np import numpy as np
...@@ -74,8 +76,8 @@ class SphereSolver(nn.Module): ...@@ -74,8 +76,8 @@ class SphereSolver(nn.Module):
cost, _ = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1) cost, _ = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
# apply cosine transform and flip them # apply cosine transform and flip them
lats = -torch.as_tensor(np.arcsin(cost)) lats = -torch.arcsin(cost)
lons = torch.linspace(0, 2*np.pi, self.nlon+1, dtype=torch.float64)[:nlon] lons = _precompute_longitudes(self.nlon)
self.lmax = self.sht.lmax self.lmax = self.sht.lmax
self.mmax = self.sht.mmax self.mmax = self.sht.mmax
...@@ -162,8 +164,8 @@ class SphereSolver(nn.Module): ...@@ -162,8 +164,8 @@ class SphereSolver(nn.Module):
#ax = plt.gca(projection=proj, frameon=True) #ax = plt.gca(projection=proj, frameon=True)
ax = fig.add_subplot(projection=proj) ax = fig.add_subplot(projection=proj)
Lons = Lons*180/np.pi Lons = Lons*180/math.pi
Lats = Lats*180/np.pi Lats = Lats*180/math.pi
# contour data over the map. # contour data over the map.
im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin) im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
...@@ -175,4 +177,4 @@ class SphereSolver(nn.Module): ...@@ -175,4 +177,4 @@ class SphereSolver(nn.Module):
return im return im
def plot_specdata(self, data, fig, **kwargs): def plot_specdata(self, data, fig, **kwargs):
return self.plot_griddata(self.isht(data), fig, **kwargs) return self.plot_griddata(self.isht(data), fig, **kwargs)
\ No newline at end of file
@@ -35,6 +35,7 @@ import torch.nn as nn
 import torch_harmonics as harmonics
 from torch_harmonics.quadrature import *

+import math
 import numpy as np
@@ -79,11 +80,11 @@ class ShallowWaterSolver(nn.Module):
         elif self.grid == "equiangular":
             cost, quad_weights = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
-            quad_weights = torch.as_tensor(quad_weights).reshape(-1, 1)
+            quad_weights = quad_weights.reshape(-1, 1)

         # apply cosine transform and flip them
-        lats = -torch.as_tensor(np.arcsin(cost))
-        lons = torch.linspace(0, 2*np.pi, self.nlon+1, dtype=torch.float64)[:nlon]
+        lats = -torch.arcsin(cost)
+        lons = _precompute_longitudes(self.nlon)

         self.lmax = self.sht.lmax
         self.mmax = self.sht.mmax
@@ -360,8 +361,8 @@ class ShallowWaterSolver(nn.Module):
         #ax = plt.gca(projection=proj, frameon=True)
         ax = fig.add_subplot(projection=proj)
-        Lons = Lons*180/np.pi
-        Lats = Lats*180/np.pi
+        Lons = Lons*180/math.pi
+        Lats = Lats*180/math.pi
         # contour data over the map.
         im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
@@ -375,8 +376,8 @@ class ShallowWaterSolver(nn.Module):
         #ax = plt.gca(projection=proj, frameon=True)
         ax = fig.add_subplot(projection=proj)
-        Lons = Lons*180/np.pi
-        Lats = Lats*180/np.pi
+        Lons = Lons*180/math.pi
+        Lats = Lats*180/math.pi
         # contour data over the map.
         im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
...
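The hunk above swaps the hand-rolled longitude grid for the new cached helper. A minimal sketch of the equivalence, assuming _precompute_longitudes is importable from torch_harmonics.quadrature as the diff indicates; the check itself is illustration, not part of the commit:

import math
import numpy as np
import torch
from torch_harmonics.quadrature import _precompute_longitudes

nlon = 8
lons = _precompute_longitudes(nlon)                      # new cached torch helper
old = np.linspace(0, 2 * math.pi, nlon, endpoint=False)  # old numpy pattern
print(torch.allclose(lons, torch.from_numpy(old)))       # True: same grid, periodic endpoint excluded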
@@ -30,23 +30,12 @@
 #
 import abc
-from typing import List, Tuple, Union, Optional
+from typing import Tuple, Union, Optional
 import math

 import torch

+from torch_harmonics.cache import lru_cache

-def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis_type: str):
-    """Factory function to generate the appropriate filter basis"""
-    if basis_type == "piecewise linear":
-        return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape)
-    elif basis_type == "morlet":
-        return MorletFilterBasis(kernel_shape=kernel_shape)
-    elif basis_type == "zernike":
-        return ZernikeFilterBasis(kernel_shape=kernel_shape)
-    else:
-        raise ValueError(f"Unknown basis_type {basis_type}")

 def _circle_dist(x1: torch.Tensor, x2: torch.Tensor):
@@ -71,7 +60,7 @@ class FilterBasis(metaclass=abc.ABCMeta):
     def __init__(
         self,
-        kernel_shape: Union[int, List[int], Tuple[int, int]],
+        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
     ):
         self.kernel_shape = kernel_shape
@@ -96,6 +85,20 @@ class FilterBasis(metaclass=abc.ABCMeta):
         raise NotImplementedError

+
+@lru_cache(typed=True, copy=False)
+def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basis_type: str) -> FilterBasis:
+    """Factory function to generate the appropriate filter basis"""
+    if basis_type == "piecewise linear":
+        return PiecewiseLinearFilterBasis(kernel_shape=kernel_shape)
+    elif basis_type == "morlet":
+        return MorletFilterBasis(kernel_shape=kernel_shape)
+    elif basis_type == "zernike":
+        return ZernikeFilterBasis(kernel_shape=kernel_shape)
+    else:
+        raise ValueError(f"Unknown basis_type {basis_type}")
+

 class PiecewiseLinearFilterBasis(FilterBasis):
     """
     Tensor-product basis on a disk constructed from piecewise linear basis functions.
@@ -103,7 +106,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
     def __init__(
         self,
-        kernel_shape: Union[int, List[int], Tuple[int, int]],
+        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
     ):

         if isinstance(kernel_shape, int):
@@ -222,7 +225,7 @@ class MorletFilterBasis(FilterBasis):
     def __init__(
         self,
-        kernel_shape: Union[int, List[int], Tuple[int, int]],
+        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
     ):

         if isinstance(kernel_shape, int):
@@ -280,7 +283,7 @@ class ZernikeFilterBasis(FilterBasis):
     def __init__(
         self,
-        kernel_shape: Union[int, Tuple[int], List[int]],
+        kernel_shape: Union[int, Tuple[int]],
     ):

         if isinstance(kernel_shape, tuple) or isinstance(kernel_shape, list):
...
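Because the factory is now wrapped in lru_cache(typed=True, copy=False), repeated requests for the same basis hand back the cached object instead of rebuilding it. A small usage sketch, assuming the factory lives in torch_harmonics.filter_basis:

from torch_harmonics.filter_basis import get_filter_basis  # module path assumed

b1 = get_filter_basis(kernel_shape=(3, 4), basis_type="morlet")
b2 = get_filter_basis(kernel_shape=(3, 4), basis_type="morlet")
print(b1 is b2)            # True: copy=False returns the cached instance itself
print(type(b1).__name__)   # MorletFilterBasis

Note that copy=False is safe here precisely because a FilterBasis is configuration, not a tensor anyone mutates in place; the tensor-producing precomputations below use copy=True instead.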
@@ -29,15 +29,20 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
-import numpy as np
+from typing import Optional
+
+import math
+import torch
+
+from torch_harmonics.cache import lru_cache

-def clm(l, m):
+def clm(l: int, m: int) -> float:
     """
     defines the normalization factor to orthonormalize the Spherical Harmonics
     """
-    return np.sqrt((2*l + 1) / 4 / np.pi) * np.sqrt(np.math.factorial(l-m) / np.math.factorial(l+m))
+    return math.sqrt((2*l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l-m) / math.factorial(l+m))

-def legpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
+def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
     r"""
     Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
     The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
@@ -52,31 +57,31 @@ def legpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
     # compute the tensor P^m_n:
     nmax = max(mmax,lmax)
-    vdm = np.zeros((nmax, nmax, len(x)), dtype=np.float64)
+    vdm = torch.zeros((nmax, nmax, len(x)), dtype=torch.float64, requires_grad=False)

-    norm_factor = 1. if norm == "ortho" else np.sqrt(4 * np.pi)
+    norm_factor = 1. if norm == "ortho" else math.sqrt(4 * math.pi)
     norm_factor = 1. / norm_factor if inverse else norm_factor

     # initial values to start the recursion
-    vdm[0,0,:] = norm_factor / np.sqrt(4 * np.pi)
+    vdm[0,0,:] = norm_factor / math.sqrt(4 * math.pi)

     # fill the diagonal and the lower diagonal
     for l in range(1, nmax):
-        vdm[l-1, l, :] = np.sqrt(2*l + 1) * x * vdm[l-1, l-1, :]
-        vdm[l, l, :] = np.sqrt( (2*l + 1) * (1 + x) * (1 - x) / 2 / l ) * vdm[l-1, l-1, :]
+        vdm[l-1, l, :] = math.sqrt(2*l + 1) * x * vdm[l-1, l-1, :]
+        vdm[l, l, :] = torch.sqrt( (2*l + 1) * (1 + x) * (1 - x) / 2 / l ) * vdm[l-1, l-1, :]

     # fill the remaining values on the upper triangle and multiply b
     for l in range(2, nmax):
         for m in range(0, l-1):
-            vdm[m, l, :] = x * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * vdm[m, l-1, :] \
-                - np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * vdm[m, l-2, :]
+            vdm[m, l, :] = x * math.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * vdm[m, l-1, :] \
+                - math.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * vdm[m, l-2, :]

     if norm == "schmidt":
         for l in range(0, nmax):
             if inverse:
-                vdm[:, l, : ] = vdm[:, l, : ] * np.sqrt(2*l + 1)
+                vdm[:, l, : ] = vdm[:, l, : ] * math.sqrt(2*l + 1)
             else:
-                vdm[:, l, : ] = vdm[:, l, : ] / np.sqrt(2*l + 1)
+                vdm[:, l, : ] = vdm[:, l, : ] / math.sqrt(2*l + 1)

     vdm = vdm[:mmax, :lmax]
@@ -86,7 +91,9 @@ def legpoly(mmax, lmax, x, norm="ortho", inverse=False, csphase=True):
     return vdm

-def _precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
+@lru_cache(typed=True, copy=True)
+def _precompute_legpoly(mmax: int, lmax: int, t: torch.Tensor,
+                        norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
     r"""
     Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
     The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
@@ -99,9 +106,11 @@ def _precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
     [3] Schrama, E.; Orbit integration based upon interpolated gravitational gradients
     """

-    return legpoly(mmax, lmax, np.cos(t), norm=norm, inverse=inverse, csphase=csphase)
+    return legpoly(mmax, lmax, torch.cos(t), norm=norm, inverse=inverse, csphase=csphase)

-def _precompute_dlegpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
+@lru_cache(typed=True, copy=True)
+def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
+                         norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
     r"""
     Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
     at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
@@ -114,32 +123,32 @@ def _precompute_dlegpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
     pct = _precompute_legpoly(mmax+1, lmax+1, t, norm=norm, inverse=inverse, csphase=False)

-    dpct = np.zeros((2, mmax, lmax, len(t)), dtype=np.float64)
+    dpct = torch.zeros((2, mmax, lmax, len(t)), dtype=torch.float64, requires_grad=False)

     # fill the derivative terms wrt theta
     for l in range(0, lmax):

         # m = 0
-        dpct[0, 0, l] = - np.sqrt(l*(l+1)) * pct[1, l]
+        dpct[0, 0, l] = - math.sqrt(l*(l+1)) * pct[1, l]

         # 0 < m < l
         for m in range(1, min(l, mmax)):
-            dpct[0, m, l] = 0.5 * ( np.sqrt((l+m)*(l-m+1)) * pct[m-1, l] - np.sqrt((l-m)*(l+m+1)) * pct[m+1, l] )
+            dpct[0, m, l] = 0.5 * ( math.sqrt((l+m)*(l-m+1)) * pct[m-1, l] - math.sqrt((l-m)*(l+m+1)) * pct[m+1, l] )

         # m == l
         if mmax > l:
-            dpct[0, l, l] = np.sqrt(l/2) * pct[l-1, l]
+            dpct[0, l, l] = math.sqrt(l/2) * pct[l-1, l]

         # fill the - 1j m P^m_l / sin(phi). as this component is purely imaginary,
         # we won't store it explicitly in a complex array
         for m in range(1, min(l+1, mmax)):
             # this component is implicitly complex
             # we do not divide by m here as this cancels with the derivative of the exponential
-            dpct[1, m, l] = 0.5 * np.sqrt((2*l+1)/(2*l+3)) * \
-                ( np.sqrt((l-m+1)*(l-m+2)) * pct[m-1, l+1] + np.sqrt((l+m+1)*(l+m+2)) * pct[m+1, l+1] )
+            dpct[1, m, l] = 0.5 * math.sqrt((2*l+1)/(2*l+3)) * \
+                ( math.sqrt((l-m+1)*(l-m+2)) * pct[m-1, l+1] + math.sqrt((l+m+1)*(l+m+2)) * pct[m+1, l+1] )

     if csphase:
         for m in range(1, mmax, 2):
             dpct[:, m] *= -1

     return dpct
\ No newline at end of file
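With this port the Legendre precomputations consume and produce torch tensors end to end, so callers no longer need a torch.from_numpy round trip, and copy=True means each call hands out a fresh tensor that is safe to scale in place. A quick smoke test of the shape contract documented in the docstrings above, under the signatures shown in this diff:

import math
import torch
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly

theta = torch.linspace(0.0, math.pi, 16, dtype=torch.float64)  # colatitudes

pct = _precompute_legpoly(8, 8, theta)
dpct = _precompute_dlegpoly(8, 8, theta)
print(pct.shape, pct.dtype)  # torch.Size([8, 8, 16]) torch.float64
print(dpct.shape)            # torch.Size([2, 8, 8, 16]): d/dtheta and 1/sin(theta) components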
@@ -29,10 +29,14 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
+from typing import Tuple, Optional
+
+from torch_harmonics.cache import lru_cache
+
+import math
 import numpy as np
+import torch

-def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
+def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0,
+                     periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
     if (grid != "equidistant") and periodic:
         raise ValueError(f"Periodic grid is only supported on equidistant grids.")
@@ -51,31 +55,41 @@ def _precompute_grid(n, grid="equidistant", a=0.0, b=1.0, periodic=False):
     return xlg, wlg

+@lru_cache(typed=True, copy=True)
+def _precompute_longitudes(nlon: int):
+    r"""
+    Convenience routine to precompute longitudes
+    """
+
+    lons = torch.linspace(0, 2 * math.pi, nlon+1, dtype=torch.float64, requires_grad=False)[:-1]
+
+    return lons
+

-def _precompute_latitudes(nlat, grid="equiangular"):
+@lru_cache(typed=True, copy=True)
+def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Convenience routine to precompute latitudes
     """

     # compute coordinates in the cosine theta domain
     xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)

     # to perform the quadrature and account for the jacobian of the sphere, the quadrature rule
     # is formulated in the cosine theta domain, which is designed to integrate functions of cos theta
-    lats = np.flip(np.arccos(xlg)).copy()
-    wlg = np.flip(wlg).copy()
+    lats = torch.flip(torch.arccos(xlg), dims=(0,)).clone()
+    wlg = torch.flip(wlg, dims=(0,)).clone()

     return lats, wlg

-def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
+def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0, periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Helper routine which returns equidistant nodes with trapezoidal weights
     on the interval [a, b]
     """

-    xlg = np.linspace(a, b, n, endpoint=periodic)
-    wlg = (b - a) / (n - periodic * 1) * np.ones(n)
+    xlg = torch.from_numpy(np.linspace(a, b, n, endpoint=periodic))
+    wlg = (b - a) / (n - periodic * 1) * torch.ones(n, requires_grad=False)

     if not periodic:
         wlg[0] *= 0.5
@@ -84,35 +98,38 @@ def trapezoidal_weights(n, a=-1.0, b=1.0, periodic=False):
     return xlg, wlg

-def legendre_gauss_weights(n, a=-1.0, b=1.0):
+def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Helper routine which returns the Legendre-Gauss nodes and weights
     on the interval [a, b]
     """

     xlg, wlg = np.polynomial.legendre.leggauss(n)
+    xlg = torch.from_numpy(xlg).clone()
+    wlg = torch.from_numpy(wlg).clone()

     xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5
     wlg = wlg * (b - a) * 0.5

     return xlg, wlg

-def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
+def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
+                    tol: Optional[float]=1e-16, maxiter: Optional[int]=100) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
     on the interval [a, b]
     """

-    wlg = np.zeros((n,))
-    tlg = np.zeros((n,))
-    tmp = np.zeros((n,))
+    wlg = torch.zeros((n,), dtype=torch.float64, requires_grad=False)
+    tlg = torch.zeros((n,), dtype=torch.float64, requires_grad=False)
+    tmp = torch.zeros((n,), dtype=torch.float64, requires_grad=False)

     # Vandermonde Matrix
-    vdm = np.zeros((n, n))
+    vdm = torch.zeros((n, n), dtype=torch.float64, requires_grad=False)

     # initialize Chebyshev nodes as first guess
     for i in range(n):
-        tlg[i] = -np.cos(np.pi * i / (n - 1))
+        tlg[i] = -torch.cos(math.pi * i / (n - 1))

     tmp = 2.0
@@ -139,7 +156,7 @@ def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
     return tlg, wlg

-def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
+def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Computation of the Clenshaw-Curtis quadrature nodes and weights.
     This implementation follows
@@ -149,26 +166,27 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
     assert n > 1

-    tcc = np.cos(np.linspace(np.pi, 0, n))
+    tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False))

     if n == 2:
-        wcc = np.array([1.0, 1.0])
+        wcc = torch.tensor([1.0, 1.0], dtype=torch.float64)
     else:
         n1 = n - 1
-        N = np.arange(1, n1, 2)
+        N = torch.arange(1, n1, 2, dtype=torch.float64)
         l = len(N)
         m = n1 - l

-        v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
-        v = 0 - v[:-1] - v[-1:0:-1]
+        v = torch.cat([2 / N / (N - 2), 1 / N[-1:], torch.zeros(m, dtype=torch.float64, requires_grad=False)])
+        #v = 0 - v[:-1] - v[-1:0:-1]
+        v = 0 - v[:-1] - torch.flip(v[1:], dims=(0,))

-        g0 = -np.ones(n1)
+        g0 = -torch.ones(n1, dtype=torch.float64, requires_grad=False)
         g0[l] = g0[l] + n1
         g0[m] = g0[m] + n1
         g = g0 / (n1**2 - 1 + (n1 % 2))

-        wcc = np.fft.ifft(v + g).real
-        wcc = np.concatenate((wcc, wcc[:1]))
+        wcc = torch.fft.ifft(v + g).real
+        wcc = torch.cat((wcc, wcc[:1]))

     # rescale
     tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
@@ -177,7 +195,7 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
     return tcc, wcc

-def fejer2_weights(n, a=-1.0, b=1.0):
+def fejer2_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Computation of the Fejer quadrature nodes and weights.
     This implementation follows
@@ -187,18 +205,19 @@ def fejer2_weights(n, a=-1.0, b=1.0):
     assert n > 2

-    tcc = np.cos(np.linspace(np.pi, 0, n))
+    tcc = torch.cos(torch.linspace(math.pi, 0, n, dtype=torch.float64, requires_grad=False))

     n1 = n - 1
-    N = np.arange(1, n1, 2)
+    N = torch.arange(1, n1, 2, dtype=torch.float64)
     l = len(N)
     m = n1 - l

-    v = np.concatenate([2 / N / (N - 2), 1 / N[-1:], np.zeros(m)])
-    v = 0 - v[:-1] - v[-1:0:-1]
+    v = torch.cat([2 / N / (N - 2), 1 / N[-1:], torch.zeros(m, dtype=torch.float64, requires_grad=False)])
+    #v = 0 - v[:-1] - v[-1:0:-1]
+    v = 0 - v[:-1] - torch.flip(v[1:], dims=(0,))

-    wcc = np.fft.ifft(v).real
-    wcc = np.concatenate((wcc, wcc[:1]))
+    wcc = torch.fft.ifft(v).real
+    wcc = torch.cat((wcc, wcc[:1]))

     # rescale
     tcc = (b - a) * 0.5 * tcc + (b + a) * 0.5
...
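After the port, the quadrature helpers return float64 torch tensors instead of numpy arrays. One invariant worth sanity-checking in a sketch: on [-1, 1] each rule's weights must sum to the interval length 2, the exact integral of the constant function 1, regardless of backend (this check is illustrative, not part of the commit):

import torch
from torch_harmonics.quadrature import (
    clenshaw_curtiss_weights,
    fejer2_weights,
    legendre_gauss_weights,
)

for rule in (legendre_gauss_weights, clenshaw_curtiss_weights, fejer2_weights):
    nodes, weights = rule(16, -1, 1)
    # each rule integrates f(x) = 1 exactly, so the weights sum to b - a = 2
    print(rule.__name__, torch.is_tensor(weights), float(weights.sum()))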
@@ -31,12 +31,12 @@
 from typing import List, Tuple, Union, Optional
 import math

-import numpy as np
+#import numpy as np
 import torch
 import torch.nn as nn

-from torch_harmonics.quadrature import _precompute_latitudes
+from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes

 class ResampleS2(nn.Module):
@@ -67,54 +67,53 @@ class ResampleS2(nn.Module):
         # for upscaling the latitudes we will use interpolation
         self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
-        self.lons_in = np.linspace(0, 2 * math.pi, nlon_in, endpoint=False)
+        self.lons_in = _precompute_longitudes(nlon_in)

         self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
-        self.lons_out = np.linspace(0, 2 * math.pi, nlon_out, endpoint=False)
+        self.lons_out = _precompute_longitudes(nlon_out)

         # in the case where some points lie outside of the range spanned by lats_in,
         # we need to expand the solution to the poles before interpolating
         self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
         if self.expand_poles:
-            self.lats_in = np.insert(self.lats_in, 0, 0.0)
-            self.lats_in = np.append(self.lats_in, np.pi)
+            self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
+                                      self.lats_in,
+                                      torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
+            #self.lats_in = np.insert(self.lats_in, 0, 0.0)
+            #self.lats_in = np.append(self.lats_in, np.pi)

         # prepare the interpolation by computing indices to the left and right of each output latitude
-        lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1
+        lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
         # make sure that we properly treat the last point if they coincide with the pole
-        lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
+        lat_idx = torch.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
         # lat_idx = np.where(self.lats_out > self.lats_in[-1], lat_idx - 1, lat_idx)
         # lat_idx = np.where(self.lats_out < self.lats_in[0], 0, lat_idx)

         # compute the interpolation weights along the latitude
-        lat_weights = torch.from_numpy((self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx]).float()
+        lat_weights = ((self.lats_out - self.lats_in[lat_idx]) / torch.diff(self.lats_in)[lat_idx]).to(torch.float32)
         lat_weights = lat_weights.unsqueeze(-1)

-        # convert to tensor
-        lat_idx = torch.LongTensor(lat_idx)

         # register buffers
         self.register_buffer("lat_idx", lat_idx, persistent=False)
         self.register_buffer("lat_weights", lat_weights, persistent=False)

         # get left and right indices but this time make sure periodicity in the longitude is handled
-        lon_idx_left = np.searchsorted(self.lons_in, self.lons_out, side="right") - 1
-        lon_idx_right = np.where(self.lons_out >= self.lons_in[-1], np.zeros_like(lon_idx_left), lon_idx_left + 1)
+        lon_idx_left = torch.searchsorted(self.lons_in, self.lons_out, side="right") - 1
+        lon_idx_right = torch.where(self.lons_out >= self.lons_in[-1], torch.zeros_like(lon_idx_left), lon_idx_left + 1)

         # get the difference
         diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
-        diff = np.where(diff < 0.0, diff + 2 * math.pi, diff)
+        diff = torch.where(diff < 0.0, diff + 2 * math.pi, diff)

-        lon_weights = torch.from_numpy((self.lons_out - self.lons_in[lon_idx_left]) / diff).float()
+        lon_weights = ((self.lons_out - self.lons_in[lon_idx_left]) / diff).to(torch.float32)

-        # convert to tensor
-        lon_idx_left = torch.LongTensor(lon_idx_left)
-        lon_idx_right = torch.LongTensor(lon_idx_right)

         # register buffers
         self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
         self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
         self.register_buffer("lon_weights", lon_weights, persistent=False)

+        self.skip_resampling = (nlon_in == nlon_out) and (nlat_in == nlat_out) and (grid_in == grid_out)

     def extra_repr(self):
         r"""
         Pretty print module
@@ -139,7 +138,7 @@ class ResampleS2(nn.Module):
         repeats[-1] = x.shape[-1]
         x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
         x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
-        x = torch.concatenate((x_north, x, x_south), dim=-2)
+        x = torch.concatenate((x_north, x, x_south), dim=-2).contiguous()

         return x

     def _upscale_latitudes(self, x: torch.Tensor):
@@ -156,6 +155,9 @@ class ResampleS2(nn.Module):
         return x

     def forward(self, x: torch.Tensor):
+        if self.skip_resampling:
+            return x
+
         if self.expand_poles:
             x = self._expand_poles(x)
         x = self._upscale_latitudes(x)
...
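The new skip_resampling flag makes forward an identity whenever the input and output grids coincide; otherwise interpolation proceeds through the registered index and weight buffers. A usage sketch; the constructor keywords are assumed from the attribute names in the diff rather than taken from documented API:

import torch
from torch_harmonics import ResampleS2  # top-level export assumed

up = ResampleS2(nlat_in=32, nlon_in=64, nlat_out=64, nlon_out=128,
                grid_in="equiangular", grid_out="equiangular")
x = torch.randn(4, 3, 32, 64)  # (batch, channels, nlat, nlon)
print(up(x).shape)             # torch.Size([4, 3, 64, 128])

same = ResampleS2(nlat_in=32, nlon_in=64, nlat_out=32, nlon_out=64,
                  grid_in="equiangular", grid_out="equiangular")
y = same(torch.randn(1, 1, 32, 64))  # identical grids: forward is a no-op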
@@ -29,7 +29,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.fft
@@ -70,17 +69,17 @@ class RealSHT(nn.Module):
         # compute quadrature points and lmax based on the exactness of the quadrature
         if self.grid == "legendre-gauss":
-            cost, w = legendre_gauss_weights(nlat, -1, 1)
+            cost, weights = legendre_gauss_weights(nlat, -1, 1)
             # maximum polynomial degree for Gauss Legendre is 2 * nlat - 1 >= 2 * lmax
             # and therefore lmax = nlat - 1 (inclusive)
             self.lmax = lmax or self.nlat
         elif self.grid == "lobatto":
-            cost, w = lobatto_weights(nlat, -1, 1)
+            cost, weights = lobatto_weights(nlat, -1, 1)
             # maximum polynomial degree for Gauss Legendre is 2 * nlat - 3 >= 2 * lmax
             # and therefore lmax = nlat - 2 (inclusive)
             self.lmax = lmax or self.nlat - 1
         elif self.grid == "equiangular":
-            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
+            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
             # in principle, Clenshaw-Curtiss quadrature is only exact up to polynomial degrees of nlat
             # however, we observe that the quadrature is remarkably accurate for higher degrees. This is why we do not
             # choose a lower lmax for now.
@@ -89,16 +88,14 @@ class RealSHT(nn.Module):
             raise (ValueError("Unknown quadrature mode"))

         # apply cosine transform and flip them
-        tq = np.flip(np.arccos(cost))
+        tq = torch.flip(torch.arccos(cost), dims=(0,))

         # determine the dimensions
         self.mmax = mmax or self.nlon // 2 + 1

         # combine quadrature weights with the legendre weights
-        weights = torch.from_numpy(w)
         pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
-        pct = torch.from_numpy(pct)
-        weights = torch.einsum("mlk,k->mlk", pct, weights)
+        weights = torch.einsum("mlk,k->mlk", pct, weights).contiguous()

         # remember quadrature weights
         self.register_buffer("weights", weights, persistent=False)
@@ -172,13 +169,12 @@ class InverseRealSHT(nn.Module):
             raise (ValueError("Unknown quadrature mode"))

         # apply cosine transform and flip them
-        t = np.flip(np.arccos(cost))
+        t = torch.flip(torch.arccos(cost), dims=(0,))

         # determine the dimensions
         self.mmax = mmax or self.nlon // 2 + 1

         pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
-        pct = torch.from_numpy(pct)

         # register buffer
         self.register_buffer("pct", pct, persistent=False)
@@ -241,32 +237,29 @@ class RealVectorSHT(nn.Module):
         # compute quadrature points
         if self.grid == "legendre-gauss":
-            cost, w = legendre_gauss_weights(nlat, -1, 1)
+            cost, weights = legendre_gauss_weights(nlat, -1, 1)
             self.lmax = lmax or self.nlat
         elif self.grid == "lobatto":
-            cost, w = lobatto_weights(nlat, -1, 1)
+            cost, weights = lobatto_weights(nlat, -1, 1)
             self.lmax = lmax or self.nlat - 1
         elif self.grid == "equiangular":
-            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
+            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
             self.lmax = lmax or self.nlat
         else:
             raise (ValueError("Unknown quadrature mode"))

         # apply cosine transform and flip them
-        tq = np.flip(np.arccos(cost))
+        tq = torch.flip(torch.arccos(cost), dims=(0,))

         # determine the dimensions
         self.mmax = mmax or self.nlon // 2 + 1

-        weights = torch.from_numpy(w)
         dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
-        dpct = torch.from_numpy(dpct)

         # combine integration weights and normalization factor into one:
         l = torch.arange(0, self.lmax)
         norm_factor = 1.0 / l / (l + 1)
         norm_factor[0] = 1.0
-        weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor)
+        weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor).contiguous()
         # since the second component is imaginary, we need to take complex conjugation into account
         weights[1] = -1 * weights[1]
@@ -356,13 +349,12 @@ class InverseRealVectorSHT(nn.Module):
             raise (ValueError("Unknown quadrature mode"))

         # apply cosine transform and flip them
-        t = np.flip(np.arccos(cost))
+        t = torch.flip(torch.arccos(cost), dims=(0,))

         # determine the dimensions
         self.mmax = mmax or self.nlon // 2 + 1

         dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
-        dpct = torch.from_numpy(dpct)

         # register weights
         self.register_buffer("dpct", dpct, persistent=False)
...
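End to end, the transforms now assemble their weights, pct, and dpct buffers from the cached torch precomputations without any torch.from_numpy glue. A roundtrip sketch over the public API to illustrate the unchanged behavior; grid sizes and the keyword-free constructor form are assumptions, not part of the commit:

import torch
from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon = 32, 64
sht = RealSHT(nlat, nlon, grid="equiangular")
isht = InverseRealSHT(nlat, nlon, grid="equiangular")

# synthesize a band-limited field from a single (l=2, m=1) coefficient ...
coeffs = torch.zeros(sht.lmax, sht.mmax, dtype=torch.complex128)
coeffs[2, 1] = 1.0
x = isht(coeffs)

# ... and verify that analysis recovers it up to quadrature error
print(torch.allclose(sht(x), coeffs, atol=1e-10))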