# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
import unittest

import torch
from parameterized import parameterized, parameterized_class
from torch.autograd import gradcheck

import torch_harmonics as th

# Devices the SHT tests are parameterized over. Each entry is a 1-tuple because
# parameterized_class expects one value per parameterized class attribute.
_devices = [(torch.device("cpu"),)] + ([(torch.device("cuda"),)] if torch.cuda.is_available() else [])


class TestLegendrePolynomials(unittest.TestCase):
    """
    Test the associated Legendre polynomials computed by ``th.legendre.legpoly``
    against closed-form reference expressions up to degree 3.
    """

    def setUp(self):
        # normalization factor relating the normalized polynomials returned by
        # legpoly to the unnormalized reference polynomials below
        self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))
        self.pml = dict()

        # preparing associated Legendre Polynomials (These include the Condon-Shortley phase)
        # for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
        self.pml[(0, 0)] = lambda x: torch.ones_like(x)
        self.pml[(0, 1)] = lambda x: x
        self.pml[(1, 1)] = lambda x: -torch.sqrt(1.0 - x**2)
        self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1)
        self.pml[(1, 2)] = lambda x: -3 * x * torch.sqrt(1.0 - x**2)
        self.pml[(2, 2)] = lambda x: 3 * (1 - x**2)
        self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x)
        self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * torch.sqrt(1.0 - x**2)
        self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2)
        self.pml[(3, 3)] = lambda x: -15 * torch.sqrt(1.0 - x**2) ** 3

        # legpoly is evaluated up to degree/order 4, but only degrees l < lmax
        # (i.e. up to 3, matching the references above) are compared
        self.lmax = self.mmax = 4

        self.tol = 1e-9

    def test_legendre(self, verbose=False):
        """
        Test the computation of associated Legendre polynomials.

        Parameters
        ----------
        verbose : bool, optional
            Whether to print verbose output, by default False
        """
        if verbose:
            print("Testing computation of associated Legendre polynomials")

        t = torch.linspace(0, 1, 100, dtype=torch.float64)
        vdm = th.legendre.legpoly(self.mmax, self.lmax, t)

        for l in range(self.lmax):
            for m in range(l + 1):
                with self.subTest(m=m, l=l):
                    diff = vdm[m, l] / self.cml(m, l) - self.pml[(m, l)](t)
                    # bound the absolute deviation: a signed maximum would let
                    # arbitrarily large negative errors pass unnoticed
                    self.assertLessEqual(diff.abs().max().item(), self.tol)


@parameterized_class(("device",), _devices)
class TestSphericalHarmonicTransform(unittest.TestCase):
    """
    Test the real-valued spherical harmonic transform (``th.RealSHT``) and its
    inverse (``th.InverseRealSHT``).

    ``parameterized_class`` expands this class into one subclass per entry of
    ``_devices``; each generated subclass carries its target device in the
    ``device`` class attribute.
    """

    def setUp(self):
        # Respect the device injected by parameterized_class. Unconditionally
        # overwriting it would make the CPU variant silently run on CUDA whenever
        # a GPU is present, leaving the CPU code path untested.
        if getattr(self, "device", None) is None:
            self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        # fixed seed so that failures are reproducible
        torch.manual_seed(333)
        if self.device.type == "cuda":
            torch.cuda.manual_seed(333)

    @staticmethod
    def _get_truncation(nlat, grid):
        """
        Return the (mmax, lmax) spectral truncation used by these tests for a grid.

        Parameters
        ----------
        nlat : int
            Number of latitude points
        grid : str
            Grid type; "equiangular" and "lobatto" get special truncations, any
            other grid (e.g. "legendre-gauss") uses mmax = nlat

        Returns
        -------
        tuple of int
            (mmax, lmax), chosen to be equal
        """
        if grid == "equiangular":
            mmax = nlat // 2
        elif grid == "lobatto":
            mmax = nlat - 1
        else:
            mmax = nlat
        return mmax, mmax

    def _random_signal(self, isht, batch_size, lmax, mmax):
        """
        Draw random complex coefficients and synthesize the corresponding signal.

        Returns
        -------
        tuple of torch.Tensor
            (coeffs, signal), where coeffs has shape (batch_size, lmax, mmax)
        """
        with torch.no_grad():
            coeffs = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
            signal = isht(coeffs)
        return coeffs, signal

    @parameterized.expand(
        [
            # even-even
            [32, 64, 32, "ortho", "equiangular", 1e-9, False],
            [32, 64, 32, "ortho", "legendre-gauss", 1e-9, False],
            [32, 64, 32, "ortho", "lobatto", 1e-9, False],
            [32, 64, 32, "four-pi", "equiangular", 1e-9, False],
            [32, 64, 32, "four-pi", "legendre-gauss", 1e-9, False],
            [32, 64, 32, "four-pi", "lobatto", 1e-9, False],
            [32, 64, 32, "schmidt", "equiangular", 1e-9, False],
            [32, 64, 32, "schmidt", "legendre-gauss", 1e-9, False],
            [32, 64, 32, "schmidt", "lobatto", 1e-9, False],
            # odd-even
            [33, 64, 32, "ortho", "equiangular", 1e-9, False],
            [33, 64, 32, "ortho", "legendre-gauss", 1e-9, False],
            [33, 64, 32, "ortho", "lobatto", 1e-9, False],
            [33, 64, 32, "four-pi", "equiangular", 1e-9, False],
            [33, 64, 32, "four-pi", "legendre-gauss", 1e-9, False],
            [33, 64, 32, "four-pi", "lobatto", 1e-9, False],
            [33, 64, 32, "schmidt", "equiangular", 1e-9, False],
            [33, 64, 32, "schmidt", "legendre-gauss", 1e-9, False],
            [33, 64, 32, "schmidt", "lobatto", 1e-9, False],
        ],
        skip_on_empty=True,
    )
    def test_forward_inverse(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        """
        Check that repeated SHT -> inverse SHT round trips reproduce a band-limited
        signal up to ``tol`` (relative Frobenius error, averaged over the batch).
        """
        if verbose:
            print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")

        testiters = [1, 2, 4, 8, 16]
        mmax, lmax = self._get_truncation(nlat, grid)

        sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
        isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

        _, signal = self._random_signal(isht, batch_size, lmax, mmax)

        # testing error accumulation
        for niter in testiters:
            with self.subTest(i=niter):
                if verbose:
                    print(f"{niter} iterations of batchsize {batch_size}:")

                base = signal
                for _ in range(niter):
                    base = isht(sht(base))

                err = torch.mean(torch.norm(base - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))
                if verbose:
                    print(f"final relative error: {err.item()}")
                self.assertLessEqual(err.item(), tol)

    @parameterized.expand(
        [
            # even-even
            [12, 24, 2, "ortho", "equiangular", 1e-5, False],
            [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "ortho", "lobatto", 1e-5, False],
            [12, 24, 2, "four-pi", "equiangular", 1e-5, False],
            [12, 24, 2, "four-pi", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "four-pi", "lobatto", 1e-5, False],
            [12, 24, 2, "schmidt", "equiangular", 1e-5, False],
            [12, 24, 2, "schmidt", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "schmidt", "lobatto", 1e-5, False],
            # odd-even
            [15, 30, 2, "ortho", "equiangular", 1e-5, False],
            [15, 30, 2, "ortho", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "ortho", "lobatto", 1e-5, False],
            [15, 30, 2, "four-pi", "equiangular", 1e-5, False],
            [15, 30, 2, "four-pi", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "four-pi", "lobatto", 1e-5, False],
            [15, 30, 2, "schmidt", "equiangular", 1e-5, False],
            [15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "schmidt", "lobatto", 1e-5, False],
        ],
        skip_on_empty=True,
    )
    def test_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        """
        Verify the autograd gradients of the SHT and inverse SHT with ``gradcheck``
        on a scalar relative-error loss.
        """
        if verbose:
            print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

        mmax, lmax = self._get_truncation(nlat, grid)

        sht = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
        isht = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

        coeffs, signal = self._random_signal(isht, batch_size, lmax, mmax)

        # test the sht
        def sht_err(x):
            return torch.mean(torch.norm(sht(x) - coeffs, p="fro", dim=(-1, -2)) / torch.norm(coeffs, p="fro", dim=(-1, -2)))

        grad_input = torch.randn_like(signal, requires_grad=True)
        self.assertTrue(gradcheck(sht_err, grad_input, eps=1e-6, atol=tol))

        # test the isht
        def isht_err(x):
            return torch.mean(torch.norm(isht(x) - signal, p="fro", dim=(-1, -2)) / torch.norm(signal, p="fro", dim=(-1, -2)))

        grad_input = torch.randn_like(coeffs, requires_grad=True)
        self.assertTrue(gradcheck(isht_err, grad_input, eps=1e-6, atol=tol))

    @parameterized.expand(
        [
            # even-even
            [12, 24, 2, "ortho", "equiangular", 1e-5, False],
            [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "ortho", "lobatto", 1e-5, False],
        ],
        skip_on_empty=True,
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_device_instantiation(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        """
        Check that transforms constructed under a ``torch.device`` context hold the
        same precomputed tensors as transforms constructed on the host.
        """
        if verbose:
            print(f"Testing device instantiation of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

        mmax, lmax = self._get_truncation(nlat, grid)

        # init on cpu
        sht_host = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
        isht_host = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)

        # init on device
        with torch.device(self.device):
            sht_device = th.RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)
            isht_device = th.InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm)

        self.assertTrue(torch.allclose(sht_host.weights.cpu(), sht_device.weights.cpu()))
        self.assertTrue(torch.allclose(isht_host.pct.cpu(), isht_device.pct.cpu()))


if __name__ == "__main__":
    # allow running this test module directly, e.g. `python test_sht.py`
    unittest.main()