test_convolution.py 13.7 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import unittest
from parameterized import parameterized
from functools import partial
import math
import numpy as np
import torch
from torch.autograd import gradcheck
from torch_harmonics import *

Thorsten Kurth's avatar
Thorsten Kurth committed
41
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
Boris Bonev's avatar
Boris Bonev committed
42

43

44
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
Boris Bonev's avatar
Boris Bonev committed
45
46
47
48
49
    """
    Discretely normalizes the convolution tensor.
    """

    kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
50
51
52
53
54
55
    correction_factor = nlon_out / nlon_in

    if basis_norm_mode == "individual":
        if transpose_normalization:
            # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
            # look at the normalization code in the actual implementation
Boris Bonev's avatar
Boris Bonev committed
56
            psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs(), dim=(1, 4), keepdim=True)
57
        else:
Boris Bonev's avatar
Boris Bonev committed
58
            psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs(), dim=(3, 4), keepdim=True)
59
60
61
62
63

    elif basis_norm_mode == "mean":
        if transpose_normalization:
            # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
            # look at the normalization code in the actual implementation
Boris Bonev's avatar
Boris Bonev committed
64
            psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs(), dim=(1, 4), keepdim=True)
65
66
            psi_norm = psi_norm.mean(dim=3, keepdim=True)
        else:
Boris Bonev's avatar
Boris Bonev committed
67
            psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs(), dim=(3, 4), keepdim=True)
68
69
70
71
72
            psi_norm = psi_norm.mean(dim=1, keepdim=True)
    elif basis_norm_mode == "none":
        psi_norm = 1.0
    else:
        raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
Boris Bonev's avatar
Boris Bonev committed
73
74

    if transpose_normalization:
75
        if merge_quadrature:
76
            psi = quad_weights.reshape(1, -1, 1, 1, 1) * psi / correction_factor
Boris Bonev's avatar
Boris Bonev committed
77
    else:
78
79
        if merge_quadrature:
            psi = quad_weights.reshape(1, 1, 1, -1, 1) * psi
Boris Bonev's avatar
Boris Bonev committed
80
81
82
83

    return psi / (psi_norm + eps)


84
85
86
def _precompute_convolution_tensor_dense(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    theta_eps=1e-3,
    transpose_normalization=False,
    basis_norm_mode="none",
    merge_quadrature=False,
):
    """
    Helper routine to compute the convolution tensor in a dense fashion.

    For every output point (t, p), each input point is expressed in spherical
    coordinates (theta, phi) of a rotated frame in which the output point lies at
    theta = 0. The filter basis is then evaluated at those coordinates.

    Parameters
    ----------
    in_shape, out_shape : tuple of int
        (nlat, nlon) of the input and output grids.
    filter_basis : object
        Provides ``kernel_size`` and ``compute_support_vals(theta, phi, r_cutoff=...)``.
        The latter returns ``(iidx, vals)``, where the columns of ``iidx`` index
        (kernel, lat_in, lon_in).
    grid_in, grid_out : str
        Grid types passed to ``_precompute_latitudes``.
    theta_cutoff : float
        Support radius of the filter.
    theta_eps : float
        Relative enlargement of ``theta_cutoff`` to avoid aliasing with the grid width.
    transpose_normalization, basis_norm_mode, merge_quadrature
        Forwarded to ``_normalize_convolution_tensor_dense``.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (kernel_size, nlat_out, nlon_out, nlat_in, nlon_in).
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # latitudes and latitude quadrature weights of both grids
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)

    # longitudes, used to compute the phi differences
    lons_in = _precompute_longitudes(nlon_in)
    lons_out = _precompute_longitudes(nlon_out)

    # effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    # quadrature weights that will be merged into the Psi tensor
    lat_weights = wout if transpose_normalization else win
    quad_weights = lat_weights.reshape(-1, 1) / nlon_in / 2.0

    # dense array holding the filter values
    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)

    # the input latitudes do not depend on the output point, so their trigonometry is hoisted out of the loops
    gamma = lats_in.reshape(-1, 1)
    sin_gamma = torch.sin(gamma)
    cos_gamma = torch.cos(gamma)

    for t in range(nlat_out):
        alpha = -lats_out[t]
        sin_alpha = torch.sin(alpha)
        cos_alpha = torch.cos(alpha)

        for p in range(nlon_out):
            beta = lons_in - lons_out[p]
            cos_beta = torch.cos(beta)

            # cartesian coordinates of the rotated input positions, shape (nlat_in, nlon_in)
            z = -cos_beta * sin_alpha * sin_gamma + cos_alpha * cos_gamma
            x = cos_alpha * cos_beta * sin_gamma + cos_gamma * sin_alpha
            y = torch.sin(beta) * sin_gamma

            # normalize instead of clipping to ensure correct range
            norm = torch.sqrt(x * x + y * y + z * z)
            x = x / norm
            y = y / norm
            z = z / norm

            # spherical coordinates in the rotated frame, with phi mapped to [0, 2*pi)
            theta = torch.arccos(z)
            phi = torch.arctan2(y, x)
            phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

            # find the indices where the rotated position falls into the support of the kernel
            iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
            out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals

    # take care of normalization and cast to float
    out = _normalize_convolution_tensor_dense(
        out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
    )
    return out.to(dtype=torch.float32)


class TestDiscreteContinuousConvolution(unittest.TestCase):
    """Compares the DISCO (transpose) convolution modules against a dense einsum reference."""

    def setUp(self):
        # seed the global RNG (CPU and all CUDA devices) so that module weights, inputs and
        # output gradients are reproducible regardless of the device the test runs on
        torch.manual_seed(333)
        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            torch.cuda.set_device(self.device.index)
        else:
            self.device = torch.device("cpu")

    # columns: batch_size, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type,
    #          basis_norm_mode, grid_in, grid_out, transpose, tol, verbose
    @parameterized.expand(
        [
            # regular convolution
            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
            # transpose convolution
            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
        ]
    )
    def test_disco_convolution(
        self,
        batch_size,
        in_channels,
        out_channels,
        in_shape,
        out_shape,
        kernel_shape,
        basis_type,
        basis_norm_mode,
        grid_in,
        grid_out,
        transpose,
        tol,
        verbose,
    ):
        """Checks psi, the forward pass and both gradients against the dense reference."""

        if verbose:
            print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device")

        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        # the cutoff radius scales with the first entry of kernel_shape (which may be a plain int)
        kernel_size_0 = kernel_shape if isinstance(kernel_shape, int) else kernel_shape[0]
        theta_cutoff = (kernel_size_0 + 1) * torch.pi / float(nlat_in - 1)

        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
        conv = Conv(
            in_channels,
            out_channels,
            in_shape,
            out_shape,
            kernel_shape,
            basis_type=basis_type,
            basis_norm_mode=basis_norm_mode,
            groups=1,
            grid_in=grid_in,
            grid_out=grid_out,
            bias=False,
            theta_cutoff=theta_cutoff,
        ).to(self.device)

        filter_basis = conv.filter_basis

        if transpose:
            # the reference tensor for the transpose convolution is built with input and output grids swapped
            psi_dense = _precompute_convolution_tensor_dense(
                out_shape,
                in_shape,
                filter_basis,
                grid_in=grid_out,
                grid_out=grid_in,
                theta_cutoff=theta_cutoff,
                transpose_normalization=transpose,
                basis_norm_mode=basis_norm_mode,
                merge_quadrature=True,
            ).to(self.device)

            psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)).to_dense()

            self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_in, nlat_out * nlon_out)), msg="psi does not match the dense reference")
        else:
            psi_dense = _precompute_convolution_tensor_dense(
                in_shape,
                out_shape,
                filter_basis,
                grid_in=grid_in,
                grid_out=grid_out,
                theta_cutoff=theta_cutoff,
                transpose_normalization=transpose,
                basis_norm_mode=basis_norm_mode,
                merge_quadrature=True,
            ).to(self.device)

            psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()

            self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in)), msg="psi does not match the dense reference")

        # independent, differentiable copy of the weight for the reference computation
        w_ref = conv.weight.detach().clone().requires_grad_(True)

        # create an input signal
        x = torch.randn(batch_size, in_channels, *in_shape, device=self.device, requires_grad=True)

        # FWD and BWD pass
        y = conv(x)
        grad_input = torch.randn_like(y)
        y.backward(grad_input)
        x_grad = x.grad.clone()

        # perform the reference computation
        x_ref = x.detach().clone().requires_grad_(True)
        if transpose:
            y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
            y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref)
        else:
            y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref)
            y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)
        y_ref.backward(grad_input)
        x_ref_grad = x_ref.grad.clone()

        # compare outputs and gradients
        self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol), msg="forward outputs differ")
        self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol), msg="input gradients differ")
        self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol), msg="weight gradients differ")

Boris Bonev's avatar
Boris Bonev committed
315

Boris Bonev's avatar
Boris Bonev committed
316
317
if __name__ == "__main__":
    # run the test suite when this module is executed directly as a script
    unittest.main()