# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
import unittest

import torch
from parameterized import parameterized, parameterized_class
from torch.autograd import gradcheck

import torch_harmonics as th

# Devices the SHT tests are parameterized over: CPU always, CUDA only when it is available.
# Each entry is a 1-tuple because parameterized_class takes one tuple of attribute values
# per generated class (here the single attribute is ``device``).
_devices = [(torch.device("cpu"),)]
if torch.cuda.is_available():
    _devices.append((torch.device("cuda"),))


class TestLegendrePolynomials(unittest.TestCase):
45
    """Test the associated Legendre polynomials (CPU/CUDA if available)."""
Boris Bonev's avatar
Boris Bonev committed
46
    def setUp(self):
Thorsten Kurth's avatar
Thorsten Kurth committed
47
        self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
Boris Bonev's avatar
Boris Bonev committed
48
49
50
51
        self.pml = dict()

        # preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
        # for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
Thorsten Kurth's avatar
Thorsten Kurth committed
52
        self.pml[(0, 0)] = lambda x: torch.ones_like(x)
Boris Bonev's avatar
Boris Bonev committed
53
        self.pml[(0, 1)] = lambda x: x
Thorsten Kurth's avatar
Thorsten Kurth committed
54
        self.pml[(1, 1)] = lambda x: -torch.sqrt(1.0 - x**2)
Boris Bonev's avatar
Boris Bonev committed
55
        self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1)
Thorsten Kurth's avatar
Thorsten Kurth committed
56
        self.pml[(1, 2)] = lambda x: -3 * x * torch.sqrt(1.0 - x**2)
Boris Bonev's avatar
Boris Bonev committed
57
58
        self.pml[(2, 2)] = lambda x: 3 * (1 - x**2)
        self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x)
Thorsten Kurth's avatar
Thorsten Kurth committed
59
        self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * torch.sqrt(1.0 - x**2)
Boris Bonev's avatar
Boris Bonev committed
60
        self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2)
Thorsten Kurth's avatar
Thorsten Kurth committed
61
        self.pml[(3, 3)] = lambda x: -15 * torch.sqrt(1.0 - x**2) ** 3
Boris Bonev's avatar
Boris Bonev committed
62
63
64

        self.lmax = self.mmax = 4

Boris Bonev's avatar
Boris Bonev committed
65
66
        self.tol = 1e-9

Thorsten Kurth's avatar
Thorsten Kurth committed
67
68
69
    def test_legendre(self, verbose=False):
        if verbose:
            print("Testing computation of associated Legendre polynomials")
Boris Bonev's avatar
Boris Bonev committed
70

Thorsten Kurth's avatar
Thorsten Kurth committed
71
        t = torch.linspace(0, 1, 100, dtype=torch.float64)
Boris Bonev's avatar
Boris Bonev committed
72
        vdm = th.legendre.legpoly(self.mmax, self.lmax, t)
Boris Bonev's avatar
Boris Bonev committed
73
74

        for l in range(self.lmax):
Boris Bonev's avatar
Boris Bonev committed
75
76
            for m in range(l + 1):
                diff = vdm[m, l] / self.cml(m, l) - self.pml[(m, l)](t)
Boris Bonev's avatar
Boris Bonev committed
77
                self.assertTrue(diff.max() <= self.tol)
Boris Bonev's avatar
Boris Bonev committed
78
79


80
@parameterized_class(("device",), _devices)
class TestSphericalHarmonicTransform(unittest.TestCase):
    """Test the spherical harmonic transform (CPU/CUDA if available).

    ``parameterized_class`` generates one copy of this class per entry of ``_devices``, each with
    the class attribute ``device`` set to that entry's ``torch.device``.
    """

    def setUp(self):
        # fixed seed so the random test coefficients are reproducible
        torch.manual_seed(333)
        if self.device.type == "cuda":
            torch.cuda.manual_seed(333)

    @staticmethod
    def _truncation(nlat, grid):
        """Return the truncation used for both ``mmax`` and ``lmax`` on the given grid.

        ``nlat // 2`` for equiangular grids, ``nlat - 1`` for Lobatto grids and ``nlat`` for any
        other grid (e.g. Legendre-Gauss).
        """
        if grid == "equiangular":
            return nlat // 2
        if grid == "lobatto":
            return nlat - 1
        return nlat

    def _random_signal(self, isht, batch_size, lmax, mmax):
        """Return ``(coeffs, signal)`` for a batch of random test fields.

        ``coeffs`` is a complex128 tensor of normally distributed values with shape
        ``(batch_size, lmax, mmax)`` on ``self.device``, and ``signal = isht(coeffs)``.
        Both are computed with autograd disabled.
        """
        with torch.no_grad():
            coeffs = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
            signal = isht(coeffs)
        return coeffs, signal

    @parameterized.expand(
        [
            # even-even
            [32, 64, 32, "ortho", "equiangular", 1e-9, False],
            [32, 64, 32, "ortho", "legendre-gauss", 1e-9, False],
            [32, 64, 32, "ortho", "lobatto", 1e-9, False],
            [32, 64, 32, "four-pi", "equiangular", 1e-9, False],
            [32, 64, 32, "four-pi", "legendre-gauss", 1e-9, False],
            [32, 64, 32, "four-pi", "lobatto", 1e-9, False],
            [32, 64, 32, "schmidt", "equiangular", 1e-9, False],
            [32, 64, 32, "schmidt", "legendre-gauss", 1e-9, False],
            [32, 64, 32, "schmidt", "lobatto", 1e-9, False],
            # odd-even
            [33, 64, 32, "ortho", "equiangular", 1e-9, False],
            [33, 64, 32, "ortho", "legendre-gauss", 1e-9, False],
            [33, 64, 32, "ortho", "lobatto", 1e-9, False],
            [33, 64, 32, "four-pi", "equiangular", 1e-9, False],
            [33, 64, 32, "four-pi", "legendre-gauss", 1e-9, False],
            [33, 64, 32, "four-pi", "lobatto", 1e-9, False],
            [33, 64, 32, "schmidt", "equiangular", 1e-9, False],
            [33, 64, 32, "schmidt", "legendre-gauss", 1e-9, False],
            [33, 64, 32, "schmidt", "lobatto", 1e-9, False],
        ],
        skip_on_empty=True,
    )
    def test_forward_inverse(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        if verbose:
            print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")

        # numbers of forward/inverse round trips after which the error is checked (ascending)
        testiters = [1, 2, 4, 8, 16]

        mmax = lmax = self._truncation(nlat, grid)

        sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
        isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

        _, signal = self._random_signal(isht, batch_size, lmax, mmax)

        # testing error accumulation: the round trips are applied incrementally, so reaching
        # num_iters round trips yields the same state as applying them from scratch, without
        # redoing the earlier ones for every entry of testiters
        base = signal
        done = 0
        for num_iters in testiters:
            with self.subTest(i=num_iters):
                if verbose:
                    print(f"{num_iters} iterations of batchsize {batch_size}:")

                with torch.no_grad():
                    while done < num_iters:
                        base = isht(sht(base))
                        done += 1

                    err = torch.mean(torch.norm(base - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))
                if verbose:
                    print(f"final relative error: {err.item()}")
                self.assertLessEqual(err.item(), tol)

    @parameterized.expand(
        [
            # even-even
            [12, 24, 2, "ortho", "equiangular", 1e-5, False],
            [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "ortho", "lobatto", 1e-5, False],
            [12, 24, 2, "four-pi", "equiangular", 1e-5, False],
            [12, 24, 2, "four-pi", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "four-pi", "lobatto", 1e-5, False],
            [12, 24, 2, "schmidt", "equiangular", 1e-5, False],
            [12, 24, 2, "schmidt", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "schmidt", "lobatto", 1e-5, False],
            # odd-even
            [15, 30, 2, "ortho", "equiangular", 1e-5, False],
            [15, 30, 2, "ortho", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "ortho", "lobatto", 1e-5, False],
            [15, 30, 2, "four-pi", "equiangular", 1e-5, False],
            [15, 30, 2, "four-pi", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "four-pi", "lobatto", 1e-5, False],
            [15, 30, 2, "schmidt", "equiangular", 1e-5, False],
            [15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "schmidt", "lobatto", 1e-5, False],
        ],
        skip_on_empty=True,
    )
    def test_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        if verbose:
            print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

        mmax = lmax = self._truncation(nlat, grid)

        sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
        isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

        coeffs, signal = self._random_signal(isht, batch_size, lmax, mmax)

        # test the sht: relative coefficient misfit as a function of a random grid signal
        def sht_err(x):
            return torch.mean(torch.norm(sht(x) - coeffs, p="fro", dim=(-1, -2)) / torch.norm(coeffs, p="fro", dim=(-1, -2)))

        grad_input = torch.randn_like(signal, requires_grad=True)
        self.assertTrue(gradcheck(sht_err, grad_input, eps=1e-6, atol=tol))

        # test the isht: relative signal misfit as a function of random (complex) coefficients
        def isht_err(x):
            return torch.mean(torch.norm(isht(x) - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))

        grad_input = torch.randn_like(coeffs, requires_grad=True)
        self.assertTrue(gradcheck(isht_err, grad_input, eps=1e-6, atol=tol))

    @parameterized.expand(
        [
            # even-even
            [12, 24, 2, "ortho", "equiangular", 1e-5, False],
            [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "ortho", "lobatto", 1e-5, False],
        ],
        skip_on_empty=True,
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_device_instantiation(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        if verbose:
            print(f"Testing device instantiation of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

        mmax = lmax = self._truncation(nlat, grid)

        # init on cpu
        sht_host = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
        isht_host = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)

        # init on device: tensors created inside the constructors default to self.device
        with torch.device(self.device):
            sht_device = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
            isht_device = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)

        # the precomputed quadrature weights / Legendre tables must not depend on where they were built
        self.assertTrue(torch.allclose(sht_host.weights.cpu(), sht_device.weights.cpu()))
        self.assertTrue(torch.allclose(isht_host.pct.cpu(), isht_device.pct.cpu()))


if __name__ == "__main__":
    # Running this file directly discovers and runs every TestCase defined in it.
    unittest.main(module="__main__")