# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import unittest
from parameterized import parameterized
import math
import torch
from torch.autograd import gradcheck
from torch_harmonics import *

class TestLegendrePolynomials(unittest.TestCase):
    """Compare ``torch_harmonics.legendre.legpoly`` against closed-form associated Legendre polynomials."""

    def setUp(self):
        # Scale factor c_l^m = sqrt((2l + 1) / (4 pi)) * sqrt((l - m)! / (l + m)!).
        # legpoly's output is divided by this factor to recover the unnormalized
        # closed forms below (presumably this matches legpoly's default
        # normalization -- confirm against torch_harmonics.legendre).
        self.cml = lambda m, l: math.sqrt((2 * l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l - m) / math.factorial(l + m))

        # Reference associated Legendre polynomials P_l^m(x), keyed by (m, l).
        # These include the Condon-Shortley phase (note the sign on odd m);
        # for reference see e.g. https://en.wikipedia.org/wiki/Associated_Legendre_polynomials
        self.pml = dict()
        self.pml[(0, 0)] = lambda x: torch.ones_like(x)
        self.pml[(0, 1)] = lambda x: x
        self.pml[(1, 1)] = lambda x: -torch.sqrt(1.0 - x**2)
        self.pml[(0, 2)] = lambda x: 0.5 * (3 * x**2 - 1)
        self.pml[(1, 2)] = lambda x: -3 * x * torch.sqrt(1.0 - x**2)
        self.pml[(2, 2)] = lambda x: 3 * (1 - x**2)
        self.pml[(0, 3)] = lambda x: 0.5 * (5 * x**3 - 3 * x)
        self.pml[(1, 3)] = lambda x: 1.5 * (1 - 5 * x**2) * torch.sqrt(1.0 - x**2)
        self.pml[(2, 3)] = lambda x: 15 * x * (1 - x**2)
        self.pml[(3, 3)] = lambda x: -15 * torch.sqrt(1.0 - x**2) ** 3

        # Truncation passed to legpoly. The reference table above covers l < 4,
        # which is exactly the range iterated in test_legendre.
        self.lmax = self.mmax = 4

        # Absolute tolerance on the pointwise error (evaluation points are float64).
        self.tol = 1e-9

    def test_legendre(self, verbose=False):
        """Check every tabulated (m, l) pair against legpoly on 100 points in [0, 1]."""
        if verbose:
            print("Testing computation of associated Legendre polynomials")
        from torch_harmonics.legendre import legpoly

        t = torch.linspace(0, 1, 100, dtype=torch.float64)
        # vdm[m, l] is compared elementwise against the reference evaluated at t.
        vdm = legpoly(self.mmax, self.lmax, t)

        for l in range(self.lmax):
            for m in range(l + 1):
                with self.subTest(m=m, l=l):
                    diff = vdm[m, l] / self.cml(m, l) - self.pml[(m, l)](t)
                    # Compare the magnitude of the error: a bare diff.max() would
                    # accept arbitrarily large negative deviations.
                    self.assertLessEqual(diff.abs().max().item(), self.tol)


class TestSphericalHarmonicTransform(unittest.TestCase):
    """Round-trip accuracy and gradient checks for RealSHT / InverseRealSHT."""

    def setUp(self):
        # Fix the RNG so the random spectral coefficients (and therefore any
        # failure) are reproducible from run to run.
        torch.manual_seed(333)

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    @staticmethod
    def _band_limits(nlat, grid):
        """Return the ``(lmax, mmax)`` truncation used for a grid with ``nlat`` latitudes.

        nlat // 2 for equiangular grids, nlat - 1 for Lobatto grids and nlat
        otherwise (e.g. Legendre-Gauss). lmax is always equal to mmax here.
        """
        if grid == "equiangular":
            mmax = nlat // 2
        elif grid == "lobatto":
            mmax = nlat - 1
        else:
            mmax = nlat
        return mmax, mmax

    @staticmethod
    def _relative_error(pred, target):
        """Batch-mean of the Frobenius-norm relative error over the last two dims."""
        return torch.mean(torch.norm(pred - target, p="fro", dim=(-1, -2)) / torch.norm(target, p="fro", dim=(-1, -2)))

    def _random_signal(self, isht, batch_size, lmax, mmax):
        """Draw random complex coefficients and synthesize the corresponding grid signal.

        Returns ``(coeffs, signal)``: ``coeffs`` has shape (batch_size, lmax, mmax)
        on ``self.device`` and ``signal = isht(coeffs)`` is computed without autograd.
        """
        with torch.no_grad():
            coeffs = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
            signal = isht(coeffs)
        return coeffs, signal

    @parameterized.expand(
        [
            [256, 512, 32, "ortho", "equiangular", 1e-9, False],
            [256, 512, 32, "ortho", "legendre-gauss", 1e-9, False],
            [256, 512, 32, "ortho", "lobatto", 1e-9, False],
            [256, 512, 32, "four-pi", "equiangular", 1e-9, False],
            [256, 512, 32, "four-pi", "legendre-gauss", 1e-9, False],
            [256, 512, 32, "four-pi", "lobatto", 1e-9, False],
            [256, 512, 32, "schmidt", "equiangular", 1e-9, False],
            [256, 512, 32, "schmidt", "legendre-gauss", 1e-9, False],
            [256, 512, 32, "schmidt", "lobatto", 1e-9, False],
        ]
    )
    def test_sht(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        # Repeated forward/inverse round trips of a band-limited signal must stay
        # within `tol` relative error, i.e. errors must not accumulate.
        if verbose:
            print(f"Testing real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization on {self.device.type} device")

        testiters = [1, 2, 4, 8, 16]
        lmax, mmax = self._band_limits(nlat, grid)

        sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
        isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

        _, signal = self._random_signal(isht, batch_size, lmax, mmax)

        # testing error accumulation
        for num_iters in testiters:
            with self.subTest(i=num_iters):
                if verbose:
                    print(f"{num_iters} iterations of batchsize {batch_size}:")

                base = signal
                for _ in range(num_iters):
                    base = isht(sht(base))

                err = self._relative_error(base, signal)
                if verbose:
                    print(f"final relative error: {err.item()}")
                self.assertLessEqual(err.item(), tol)

    @parameterized.expand(
        [
            [12, 24, 2, "ortho", "equiangular", 1e-5, False],
            [12, 24, 2, "ortho", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "ortho", "lobatto", 1e-5, False],
            [12, 24, 2, "four-pi", "equiangular", 1e-5, False],
            [12, 24, 2, "four-pi", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "four-pi", "lobatto", 1e-5, False],
            [12, 24, 2, "schmidt", "equiangular", 1e-5, False],
            [12, 24, 2, "schmidt", "legendre-gauss", 1e-5, False],
            [12, 24, 2, "schmidt", "lobatto", 1e-5, False],
            [15, 30, 2, "ortho", "equiangular", 1e-5, False],
            [15, 30, 2, "ortho", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "ortho", "lobatto", 1e-5, False],
            [15, 30, 2, "four-pi", "equiangular", 1e-5, False],
            [15, 30, 2, "four-pi", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "four-pi", "lobatto", 1e-5, False],
            [15, 30, 2, "schmidt", "equiangular", 1e-5, False],
            [15, 30, 2, "schmidt", "legendre-gauss", 1e-5, False],
            [15, 30, 2, "schmidt", "lobatto", 1e-5, False],
        ]
    )
    def test_sht_grads(self, nlat, nlon, batch_size, norm, grid, tol, verbose):
        # Numerically verify the autograd gradients of both transforms via
        # gradcheck on a scalar relative-error objective.
        if verbose:
            print(f"Testing gradients of real-valued SHT on {nlat}x{nlon} {grid} grid with {norm} normalization")

        lmax, mmax = self._band_limits(nlat, grid)

        sht = RealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)
        isht = InverseRealSHT(nlat, nlon, mmax=mmax, lmax=lmax, grid=grid, norm=norm).to(self.device)

        coeffs, signal = self._random_signal(isht, batch_size, lmax, mmax)

        # test the sht: gradient w.r.t. a random (real) grid input
        grad_input = torch.randn_like(signal, requires_grad=True)

        def sht_objective(x):
            return self._relative_error(sht(x), coeffs)

        test_result = gradcheck(sht_objective, grad_input, eps=1e-6, atol=tol)
        self.assertTrue(test_result)

        # test the isht: gradient w.r.t. random (complex) spectral coefficients
        grad_input = torch.randn_like(coeffs, requires_grad=True)

        def isht_objective(x):
            return self._relative_error(isht(x), signal)

        test_result = gradcheck(isht_objective, grad_input, eps=1e-6, atol=tol)
        self.assertTrue(test_result)


if __name__ == "__main__":
    # Executed as a script: collect and run every TestCase defined in this module.
    unittest.main(module="__main__")