# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
import unittest

import torch
from parameterized import parameterized
from torch.autograd import gradcheck

from torch_harmonics import *


class TestLegendrePolynomials(unittest.TestCase):
    """Compare ``torch_harmonics.legendre.legpoly`` against closed-form associated Legendre polynomials."""

    def setUp(self):
        # cml(m, l) = sqrt((2l+1)/(4*pi)) * sqrt((l-m)!/(l+m)!); test_legendre divides legpoly's
        # output by this factor before comparing it to the unnormalized closed forms in self.pml
        self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))

        # closed-form reference polynomials keyed by (m, l)
        self.pml = dict()

        # preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
        # for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
        self.pml[(0, 0)] = lambda x: torch.ones_like(x)
        self.pml[(0, 1)] = lambda x: x
        self.pml[(1, 1)] = lambda x: -torch.sqrt(1.0 - x**2)
        self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1)
        self.pml[(1, 2)] = lambda x: -3 * x * torch.sqrt(1.0 - x**2)
        self.pml[(2, 2)] = lambda x: 3 * (1 - x**2)
        self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x)
        self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * torch.sqrt(1.0 - x**2)
        self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2)
        self.pml[(3, 3)] = lambda x: -15 * torch.sqrt(1.0 - x**2) ** 3

        # legpoly is evaluated up to degree/order 4, but only l < 4 has a closed-form reference above
        self.lmax = self.mmax = 4

        # absolute tolerance on the pointwise deviation
        self.tol = 1e-9

    def test_legendre(self, verbose=False):
        """Check every (m, l) with m <= l < lmax on 100 points of [0, 1]."""
        if verbose:
            print("Testing computation of associated Legendre polynomials")
        from torch_harmonics.legendre import legpoly

        t = torch.linspace(0, 1, 100, dtype=torch.float64)
        vdm = legpoly(self.mmax, self.lmax, t)

        for l in range(self.lmax):
            for m in range(l + 1):
                with self.subTest(l=l, m=m):
                    diff = vdm[m, l] / self.cml(m, l) - self.pml[(m, l)](t)
                    # use the absolute deviation: diff.max() alone would let arbitrarily
                    # large negative errors pass the check
                    self.assertLessEqual(diff.abs().max().item(), self.tol)


class TestSphericalHarmonicTransform(unittest.TestCase):
    """Round-trip accuracy and gradient checks for RealSHT / InverseRealSHT."""

    def setUp(self):
        # run on the GPU when one is available, otherwise on the CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _rel_fro_error(approx, reference):
        """Return the batch-mean of ||approx - reference||_F / ||reference||_F over the last two dims."""
        num = torch.linalg.norm(approx - reference, ord="fro", dim=(-1, -2))
        den = torch.linalg.norm(reference, ord="fro", dim=(-1, -2))
        return torch.mean(num / den)

    def _make_transforms(self, nlat, nlon, batch_size, norm, grid):
        """Build a matching forward/inverse SHT pair and a random test signal.

        The signal is synthesized by applying the inverse transform to random complex128
        coefficients of shape (batch_size, lmax, mmax).

        Returns:
            (sht, isht, coeffs, signal)
        """
        # the equiangular grid is tested with half as many modes as the Gauss grid
        mmax = nlat // 2 if grid == "equiangular" else nlat
        lmax = mmax

        sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
        isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

        with torch.no_grad():
            coeffs = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
            signal = isht(coeffs)

        return sht, isht, coeffs, signal

    # columns: nlat, nlon, batch_size, norm, grid, tol, verbose
    @parameterized.expand(
        [
            [256, 512, 32, "ortho", "equiangular", 1e-9, False],
            [256, 512, 32, "ortho", "legendre-gauss", 1e-9, False],
            [256, 512, 32, "four-pi", "equiangular", 1e-9, False],
            [256, 512, 32, "four-pi", "legendre-gauss", 1e-9, False],
            [256, 512, 32, "schmidt", "equiangular", 1e-9, False],
            [256, 512, 32, "schmidt", "legendre-gauss", 1e-9, False],
        ]
    )
    def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        """Repeated isht(sht(.)) round trips must reproduce a band-limited signal within ``tol``."""
        if verbose:
            print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")

        testiters = [1, 2, 4, 8, 16]

        sht, isht, _, signal = self._make_transforms(nlat, nlon, batch_size, norm, grid)

        # testing error accumulation over an increasing number of round trips
        for niter in testiters:
            with self.subTest(i=niter):
                if verbose:
                    print(f"{niter} iterations of batchsize {batch_size}:")

                base = signal
                for _ in range(niter):
                    base = isht(sht(base))

                err = self._rel_fro_error(base, signal)
                if verbose:
                    print(f"final relative error: {err.item()}")
                self.assertLessEqual(err.item(), tol)

    # columns: nlat, nlon, batch_size, norm, grid, tol, verbose
    @parameterized.expand(
        [
            [12, 24, 2, "ortho", "equiangular", 1e-5, False],
            [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "four-pi", "equiangular", 1e-5, False],
            [12, 24, 2, "four-pi", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "schmidt", "equiangular", 1e-5, False],
            [12, 24, 2, "schmidt", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "ortho", "equiangular", 1e-5, False],
            [15, 30, 2, "ortho", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "four-pi", "equiangular", 1e-5, False],
            [15, 30, 2, "four-pi", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "schmidt", "equiangular", 1e-5, False],
            [15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
        ]
    )
    def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        """Finite-difference gradient checks of relative-error losses through sht and isht."""
        if verbose:
            print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

        sht, isht, coeffs, signal = self._make_transforms(nlat, nlon, batch_size, norm, grid)

        # test the sht: differentiate the loss with respect to a random real-valued input field
        grad_input = torch.randn_like(signal, requires_grad=True)
        err_handle = lambda x: self._rel_fro_error(sht(x), coeffs)
        test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
        self.assertTrue(test_result)

        # test the isht: differentiate the loss with respect to random complex coefficients
        grad_input = torch.randn_like(coeffs, requires_grad=True)
        err_handle = lambda x: self._rel_fro_error(isht(x), signal)
        test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
        self.assertTrue(test_result)


if __name__ == "__main__":
    # discover and run every TestCase defined in this module when executed as a script
    unittest.main(module="__main__")