test_convolution.py 16.1 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import unittest
from parameterized import parameterized
from functools import partial
import math
import numpy as np
import torch
from torch.autograd import gradcheck
from torch_harmonics import *

Boris Bonev's avatar
Boris Bonev committed
41
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
Boris Bonev's avatar
Boris Bonev committed
42

43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# def _compute_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
#     """
#     helper routine to compute the values of the isotropic kernel densely
#     """

#     kernel_size = (nr // 2) + nr % 2
#     ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
#     dr = 2 * r_cutoff / (nr + 1)

#     # compute the support
#     if nr % 2 == 1:
#         ir = ikernel * dr
#     else:
#         ir = (ikernel + 0.5) * dr

#     vals = torch.where(
#         ((r - ir).abs() <= dr) & (r <= r_cutoff),
#         (1 - (r - ir).abs() / dr),
#         0,
#     )

#     return vals


# def _compute_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
#     """
#     helper routine to compute the values of the anisotropic kernel densely
#     """

#     kernel_size = (nr // 2) * nphi + nr % 2
#     ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
#     dr = 2 * r_cutoff / (nr + 1)
#     dphi = 2.0 * math.pi / nphi

#     # disambiguate even and uneven cases and compute the support
#     if nr % 2 == 1:
#         ir = ((ikernel - 1) // nphi + 1) * dr
#         iphi = ((ikernel - 1) % nphi) * dphi
#     else:
#         ir = (ikernel // nphi + 0.5) * dr
#         iphi = (ikernel % nphi) * dphi

#     # compute the value of the filter
#     if nr % 2 == 1:
#         # find the indices where the rotated position falls into the support of the kernel
#         cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
#         cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
#         r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
#         phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
#         vals = torch.where(ikernel > 0, r_vals * phi_vals, r_vals)
#     else:
#         # find the indices where the rotated position falls into the support of the kernel
#         cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
#         cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
#         r_vals = torch.where(cond_r, (1 - (r - ir).abs() / dr), 0.0)
#         phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2 * math.pi - (phi - iphi).abs())) / dphi), 0.0)
#         vals = r_vals * phi_vals

#         # in the even case, the inner basis functions overlap into the region of negative radius
#         rn = -r
#         phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
#         cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
#         cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
#         rn_vals = torch.where(cond_rn, (1 - (rn - ir).abs() / dr), 0.0)
#         phin_vals = torch.where(cond_phin, (1 - torch.minimum((phin - iphi).abs(), (2 * math.pi - (phin - iphi).abs())) / dphi), 0.0)
#         vals += rn_vals * phin_vals

#     return vals


def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
    """
    Discretely normalizes a dense convolution tensor and optionally merges in the quadrature weights.

    Parameters
    ----------
    psi : torch.Tensor
        Dense filter tensor, unpacked as (kernel_size, nlat_out, nlon_out, nlat_in, nlon_in).
    quad_weights : torch.Tensor
        Latitude quadrature weights. They are broadcast along dim 1 (output latitude) when
        ``transpose_normalization`` is set and along dim 3 (input latitude) otherwise.
    transpose_normalization : bool
        Use the normalization of the transpose convolution.
    basis_norm_mode : str
        "none" (no normalization), "individual" (a separate norm for every basis function and every
        point of the non-reduced axes) or "mean" (the individual norms averaged over latitude).
    merge_quadrature : bool
        Multiply the quadrature weights into the returned tensor.
    eps : float
        Added to the norm before dividing, to avoid division by zero.

    Returns
    -------
    torch.Tensor
        ``psi`` (with merged quadrature if requested) divided by ``psi_norm + eps``.

    Raises
    ------
    ValueError
        If ``basis_norm_mode`` is not one of the supported modes.
    """

    # unpacking also enforces that psi is 5-dimensional
    _, _, nlon_out, _, nlon_in = psi.shape
    correction_factor = nlon_out / nlon_in

    # The transpose normalization is not quite symmetric to the forward one, due to the compressed way
    # psi is stored in the main code: only the first output longitude enters the norm, and the
    # quadrature runs over dim 1 instead of dim 3.
    if transpose_normalization:
        weight_view, norm_input, sum_dims, avg_dim = (1, -1, 1, 1, 1), psi[:, :, :1], (1, 4), 3
    else:
        weight_view, norm_input, sum_dims, avg_dim = (1, 1, 1, -1, 1), psi, (3, 4), 1

    if basis_norm_mode == "none":
        psi_norm = 1.0
    elif basis_norm_mode in ("individual", "mean"):
        weighted_sq = quad_weights.reshape(weight_view) * norm_input.abs().pow(2)
        psi_norm = torch.sqrt(torch.sum(weighted_sq, dim=sum_dims, keepdim=True) / 4 / math.pi)
        if basis_norm_mode == "mean":
            # average the per-latitude norms
            psi_norm = psi_norm.mean(dim=avg_dim, keepdim=True)
    else:
        raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

    if merge_quadrature:
        psi = quad_weights.reshape(weight_view) * psi
        if transpose_normalization:
            # account for the differing longitude resolutions of the two grids
            psi = psi / correction_factor

    return psi / (psi_norm + eps)


154
155
156
157
def _precompute_convolution_tensor_dense(
    in_shape,
    out_shape,
    kernel_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    transpose_normalization=False,
    basis_norm_mode="none",
    merge_quadrature=False,
):
    """
    Helper routine to compute the convolution tensor in a dense fashion, as a reference for the sparse implementation.

    For every output point (t, p), the input grid points are rotated relative to that output point and
    converted to spherical coordinates (theta, phi). The filter basis is evaluated at those coordinates,
    and its values are scattered into the dense tensor.

    Parameters
    ----------
    in_shape, out_shape : tuple of int
        (nlat, nlon) of the input and output grids.
    kernel_shape :
        Not used by this routine. The number of basis functions is taken from
        ``filter_basis.kernel_size``; the parameter is kept for interface compatibility.
    filter_basis :
        Object providing ``kernel_size`` and ``compute_support_vals(theta, phi, r_cutoff=...)``.
    grid_in, grid_out : str
        Grid names forwarded to ``_precompute_latitudes``.
    theta_cutoff : float
        Support radius forwarded to the filter basis as ``r_cutoff``.
    transpose_normalization, basis_norm_mode, merge_quadrature :
        Forwarded to ``_normalize_convolution_tensor_dense``.

    Returns
    -------
    torch.Tensor
        Tensor of shape (kernel_size, nlat_out, nlon_out, nlat_in, nlon_in).
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    # the number of basis functions is defined by the filter basis, not by kernel_shape
    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # call the explicitly imported helper instead of relying on the star import to expose the quadrature submodule
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # longitudes: the linspace is made exclusive so that 2*pi does not duplicate the first point
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
    lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]

    # compute quadrature weights that will be merged into the Psi tensor
    if transpose_normalization:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
    else:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in

    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)

    # the input-latitude terms do not depend on the output point, so compute them once
    gamma = lats_in.reshape(-1, 1)
    cos_gamma = torch.cos(gamma)
    sin_gamma = torch.sin(gamma)

    for t in range(nlat_out):
        # the output-latitude terms only depend on t
        alpha = -lats_out[t]
        cos_alpha = torch.cos(alpha)
        sin_alpha = torch.sin(alpha)

        for p in range(nlon_out):
            beta = lons_in - lons_out[p]
            cos_beta = torch.cos(beta)
            sin_beta = torch.sin(beta)

            # compute latitude of the rotated position
            z = -cos_beta * sin_alpha * sin_gamma + cos_alpha * cos_gamma

            # compute cartesian coordinates of the rotated position
            x = cos_alpha * cos_beta * sin_gamma + cos_gamma * sin_alpha
            y = sin_beta * sin_gamma

            # normalize instead of clipping to ensure correct range
            norm = torch.sqrt(x * x + y * y + z * z)
            x = x / norm
            y = y / norm
            z = z / norm

            # compute spherical coordinates; negative azimuths are shifted by 2*pi so that phi is non-negative
            theta = torch.arccos(z)
            phi = torch.arctan2(y, x)
            phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

            # evaluate the basis on its support; each row of iidx is (basis index, input lat index, input lon index)
            iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)
            out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals

    # take care of normalization
    out = _normalize_convolution_tensor_dense(
        out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
    )

    return out


class TestDiscreteContinuousConvolution(unittest.TestCase):
    """Tests the sparse DISCO (transpose) convolutions against a dense einsum-based reference."""

    def setUp(self):
        # torch.manual_seed seeds the generators of all devices, so the random weights, inputs and
        # output gradients below are reproducible (they are all created on the CPU)
        torch.manual_seed(333)

        # NOTE(review): the tests are pinned to the CPU even when CUDA is available - confirm whether GPU coverage is intended
        self.device = torch.device("cpu")

    @parameterized.expand(
        [
            # regular convolution
            [8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
            [8, 4, 2, (16, 32), (8, 16), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
            [8, 4, 2, (24, 48), (12, 24), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
            [8, 4, 2, (24, 48), (12, 24), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
            [8, 4, 2, (24, 48), (12, 24), [2, 2], "morlet", "mean", "equiangular", "equiangular", False, 1e-4],
            [8, 4, 2, (16, 24), (8, 8), [3], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
            [8, 4, 2, (18, 36), (6, 12), [7], "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4],
            [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4],
            [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4],
            [8, 4, 2, (16, 32), (8, 16), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4],
            # transpose convolution
            [8, 4, 2, (16, 32), (16, 32), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
            [8, 4, 2, (12, 24), (24, 48), [3, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
            [8, 4, 2, (12, 24), (24, 48), [4, 3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
            [8, 4, 2, (12, 24), (24, 48), [2, 2], "morlet", "mean", "equiangular", "equiangular", True, 1e-4],
            [8, 4, 2, (8, 8), (16, 24), [3], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
            [8, 4, 2, (6, 12), (18, 36), [7], "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4],
            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4],
            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4],
            [8, 4, 2, (8, 16), (16, 32), [5], "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4],
        ]
    )
    def test_disco_convolution(
        self,
        batch_size,
        in_channels,
        out_channels,
        in_shape,
        out_shape,
        kernel_shape,
        basis_type,
        basis_norm_mode,
        grid_in,
        grid_out,
        transpose,
        tol,
    ):
        """
        Checks the stored sparse psi tensor, the forward pass and the input/weight gradients of a
        (transpose) DISCO convolution against a dense reference computed with einsum.
        """
        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)

        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
        conv = Conv(
            in_channels,
            out_channels,
            in_shape,
            out_shape,
            kernel_shape,
            basis_type=basis_type,
            basis_norm_mode=basis_norm_mode,
            groups=1,
            grid_in=grid_in,
            grid_out=grid_out,
            bias=False,
            theta_cutoff=theta_cutoff,
        ).to(self.device)

        filter_basis = conv.filter_basis

        # the dense reference of the transpose convolution is built with the roles of the two grids swapped
        if transpose:
            dense_in_shape, dense_out_shape = out_shape, in_shape
            dense_grid_in, dense_grid_out = grid_out, grid_in
            psi_size = (conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)
        else:
            dense_in_shape, dense_out_shape = in_shape, out_shape
            dense_grid_in, dense_grid_out = grid_in, grid_out
            psi_size = (conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)

        psi_dense = _precompute_convolution_tensor_dense(
            dense_in_shape,
            dense_out_shape,
            kernel_shape,
            filter_basis,
            grid_in=dense_grid_in,
            grid_out=dense_grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=transpose,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        ).to(self.device)

        # the module's sparse psi corresponds to the first longitude (index 0 on dim 2) of the dense tensor
        psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=psi_size).to_dense()
        psi_ref = psi_dense[:, :, 0].reshape(-1, dense_out_shape[0], dense_in_shape[0] * dense_in_shape[1])
        # assert_close, unlike allclose, does not broadcast, so shape mismatches fail; the tolerances match allclose's defaults
        torch.testing.assert_close(psi, psi_ref, rtol=1e-5, atol=1e-8)

        # create an independent copy of the weight for the reference computation
        w_ref = conv.weight.detach().clone()
        w_ref.requires_grad = True

        # create an input signal
        x = torch.randn(batch_size, in_channels, *in_shape, device=self.device)

        # FWD and BWD pass
        x.requires_grad = True
        y = conv(x)
        grad_output = torch.randn_like(y)
        y.backward(grad_output)
        x_grad = x.grad.clone()

        # perform the reference computation
        x_ref = x.clone().detach()
        x_ref.requires_grad = True
        if transpose:
            y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
            y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref)
        else:
            y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref)
            y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)
        y_ref.backward(grad_output)
        x_ref_grad = x_ref.grad.clone()

        # compare outputs, input gradients and weight gradients
        torch.testing.assert_close(y, y_ref, rtol=tol, atol=tol)
        torch.testing.assert_close(x_grad, x_ref_grad, rtol=tol, atol=tol)
        torch.testing.assert_close(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol)

Boris Bonev's avatar
Boris Bonev committed
376

Boris Bonev's avatar
Boris Bonev committed
377
378
if __name__ == "__main__":
    # run this module's test cases through unittest's command-line runner when executed as a script
    unittest.main()