# test_convolution.py
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math
import unittest
from functools import partial

import numpy as np
import torch
from parameterized import parameterized, parameterized_class
from torch.autograd import gradcheck

from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from torch_harmonics.convolution import _precompute_convolution_tensor_s2
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes

# Devices the test class below is parameterized over: always the CPU, plus CUDA when it is
# available. Each entry is a 1-tuple holding the value for the single "device" attribute.
_devices = [(torch.device(kind),) for kind in ("cpu", "cuda") if kind == "cpu" or torch.cuda.is_available()]


49
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
Boris Bonev's avatar
Boris Bonev committed
50
51
52
53
54
    """
    Discretely normalizes the convolution tensor.
    """

    kernel_size, nlat_out, nlon_out, nlat_in, nlon_in = psi.shape
55
56
57
58
59
60
    correction_factor = nlon_out / nlon_in

    if basis_norm_mode == "individual":
        if transpose_normalization:
            # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
            # look at the normalization code in the actual implementation
Boris Bonev's avatar
Boris Bonev committed
61
            psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs(), dim=(1, 4), keepdim=True)
62
        else:
Boris Bonev's avatar
Boris Bonev committed
63
            psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs(), dim=(3, 4), keepdim=True)
64
65
66
67
68

    elif basis_norm_mode == "mean":
        if transpose_normalization:
            # the normalization is not quite symmetric due to the compressed way psi is stored in the main code
            # look at the normalization code in the actual implementation
Boris Bonev's avatar
Boris Bonev committed
69
            psi_norm = torch.sum(quad_weights.reshape(1, -1, 1, 1, 1) * psi[:, :, :1].abs(), dim=(1, 4), keepdim=True)
70
71
            psi_norm = psi_norm.mean(dim=3, keepdim=True)
        else:
Boris Bonev's avatar
Boris Bonev committed
72
            psi_norm = torch.sum(quad_weights.reshape(1, 1, 1, -1, 1) * psi.abs(), dim=(3, 4), keepdim=True)
73
74
75
76
77
            psi_norm = psi_norm.mean(dim=1, keepdim=True)
    elif basis_norm_mode == "none":
        psi_norm = 1.0
    else:
        raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")
Boris Bonev's avatar
Boris Bonev committed
78
79

    if transpose_normalization:
80
        if merge_quadrature:
81
            psi = quad_weights.reshape(1, -1, 1, 1, 1) * psi / correction_factor
Boris Bonev's avatar
Boris Bonev committed
82
    else:
83
84
        if merge_quadrature:
            psi = quad_weights.reshape(1, 1, 1, -1, 1) * psi
Boris Bonev's avatar
Boris Bonev committed
85
86
87
88

    return psi / (psi_norm + eps)


def _precompute_convolution_tensor_dense(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    theta_eps=1e-3,
    transpose_normalization=False,
    basis_norm_mode="none",
    merge_quadrature=False,
):
    """
    Build the convolution tensor densely, as a slow but straightforward reference.

    For every output grid point ``(ilat, ilon)`` the coordinates of all input points are
    rotated using ``alpha = -lat_out`` and ``beta = lon_in - lon_out``. They are then
    converted to spherical coordinates ``(theta, phi)``, and the filter basis is evaluated
    there via ``filter_basis.compute_support_vals``.

    Parameters
    ----------
    in_shape, out_shape : tuple of int
        ``(nlat, nlon)`` of the input and output grids.
    filter_basis
        Object exposing ``kernel_size`` and ``compute_support_vals(theta, phi, r_cutoff=...)``.
    grid_in, grid_out : str
        Grid types passed to ``quadrature._precompute_latitudes``.
    theta_cutoff : float
        Support radius of the filter.
    theta_eps : float
        Relative enlargement of ``theta_cutoff`` that guards against aliasing with the grid width.
    transpose_normalization, basis_norm_mode, merge_quadrature
        Forwarded to ``_normalize_convolution_tensor_dense``.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)``.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size
    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
    lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)

    # longitudes, used below to form the phi differences
    lons_in = _precompute_longitudes(nlon_in)
    lons_out = _precompute_longitudes(nlon_out)

    # effective theta cutoff if multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    # quadrature weights that will be merged into psi: same scaling in both cases,
    # only the latitude weights they are taken from differ
    lat_weights = wout if transpose_normalization else win
    quad_weights = lat_weights.reshape(-1, 1) / nlon_in / 2.0

    dense = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64, device=lons_in.device)

    # the input latitudes do not depend on the output point, so their trig terms are hoisted
    gamma = lats_in.reshape(-1, 1)
    sin_gamma = torch.sin(gamma)
    cos_gamma = torch.cos(gamma)

    for ilat in range(nlat_out):
        alpha = -lats_out[ilat]
        sin_alpha = torch.sin(alpha)
        cos_alpha = torch.cos(alpha)

        for ilon in range(nlon_out):
            beta = lons_in - lons_out[ilon]
            sin_beta = torch.sin(beta)
            cos_beta = torch.cos(beta)

            # cartesian coordinates of the rotated positions
            z = -cos_beta * sin_alpha * sin_gamma + cos_alpha * cos_gamma
            x = cos_alpha * cos_beta * sin_gamma + cos_gamma * sin_alpha
            y = sin_beta * sin_gamma * torch.ones_like(alpha)

            # normalize instead of clipping to keep arccos in range
            norm = torch.sqrt(x * x + y * y + z * z)
            x, y, z = x / norm, y / norm, z / norm

            # spherical coordinates, with phi mapped to [0, 2 pi)
            theta = torch.arccos(z)
            phi = torch.arctan2(y, x)
            phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

            # scatter the basis values of the points inside the kernel support
            iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
            dense[iidx[:, 0], ilat, ilon, iidx[:, 1], iidx[:, 2]] = vals

    dense = _normalize_convolution_tensor_dense(
        dense, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
    )
    return dense.to(dtype=torch.float32)


@parameterized_class(("device"), _devices)
class TestDiscreteContinuousConvolution(unittest.TestCase):
    """
    Tests for DiscreteContinuousConvS2 and DiscreteContinuousConvTransposeS2.

    The class is parameterized over ``_devices``, which provides ``self.device``.
    """

    def setUp(self):
        """Seed the RNGs so that module weights and random inputs are reproducible."""
        torch.manual_seed(333)
        if self.device.type == "cuda":
            torch.cuda.manual_seed(333)

    @staticmethod
    def _theta_cutoff(kernel_shape, nlat_in):
        """
        Return the filter cutoff radius used by the tests.

        The radius is ``(n + 1) * pi / (nlat_in - 1)``, where ``n`` is ``kernel_shape`` when it
        is an int and ``kernel_shape[0]`` otherwise.
        """
        nradial = kernel_shape if isinstance(kernel_shape, int) else kernel_shape[0]
        return (nradial + 1) * torch.pi / float(nlat_in - 1)

    @staticmethod
    def _make_conv(conv_cls, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, theta_cutoff):
        """Instantiate ``conv_cls`` with the options shared by all tests (one group, no bias)."""
        return conv_cls(
            in_channels,
            out_channels,
            in_shape,
            out_shape,
            kernel_shape,
            basis_type=basis_type,
            basis_norm_mode=basis_norm_mode,
            groups=1,
            grid_in=grid_in,
            grid_out=grid_out,
            bias=False,
            theta_cutoff=theta_cutoff,
        )

    # note: "(3)" is the int 3, not a tuple; kernel_shape may be an int or a tuple
    @parameterized.expand(
        [
            # regular convolution
            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (3), "zernike", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 24), (8, 8), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (18, 36), (6, 12), (7), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
            # transpose convolution
            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (3), "zernike", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 8), (16, 24), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (6, 12), (18, 36), (7), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
        ],
        skip_on_empty=True,
    )
    def test_forward_backward(
        self,
        batch_size,
        in_channels,
        out_channels,
        in_shape,
        out_shape,
        kernel_shape,
        basis_type,
        basis_norm_mode,
        grid_in,
        grid_out,
        transpose,
        tol,
        verbose,
    ):
        """
        Compare the module's sparse psi, forward output and gradients against the dense
        reference built by ``_precompute_convolution_tensor_dense``.
        """
        if verbose:
            print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device")

        nlat_in, nlon_in = in_shape
        nlat_out, nlon_out = out_shape

        theta_cutoff = self._theta_cutoff(kernel_shape, nlat_in)

        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
        conv = self._make_conv(Conv, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, theta_cutoff).to(self.device)

        # The transpose convolution is checked against the dense tensor built with the
        # shapes and grids swapped; the layout of the module's sparse psi differs accordingly.
        if transpose:
            dense_in_shape, dense_out_shape = out_shape, in_shape
            dense_grid_in, dense_grid_out = grid_out, grid_in
            sparse_size = (conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)
            ref_view = (-1, nlat_in, nlat_out * nlon_out)
        else:
            dense_in_shape, dense_out_shape = in_shape, out_shape
            dense_grid_in, dense_grid_out = grid_in, grid_out
            sparse_size = (conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)
            ref_view = (-1, nlat_out, nlat_in * nlon_in)

        psi_dense = _precompute_convolution_tensor_dense(
            dense_in_shape,
            dense_out_shape,
            conv.filter_basis,
            grid_in=dense_grid_in,
            grid_out=dense_grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=transpose,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        ).to(self.device)

        # the sparse psi is compared with the dense slice at index 0 of dimension 2
        # (presumably the module stores a single longitude and relies on rotational symmetry)
        psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=sparse_size).to_dense()
        psi_ref = psi_dense[:, :, 0].reshape(ref_view)
        self.assertEqual(psi.shape, psi_ref.shape)
        self.assertTrue(torch.allclose(psi, psi_ref))

        # reference weights: an independent leaf tensor holding a copy of the module weights
        w_ref = conv.weight.detach().clone().requires_grad_(True)

        # forward and backward pass through the module
        x = torch.randn(batch_size, in_channels, *in_shape, device=self.device)
        x.requires_grad = True
        y = conv(x)
        grad_out = torch.randn_like(y)
        y.backward(grad_out)
        x_grad = x.grad.clone()

        # dense reference computation
        x_ref = x.clone().detach()
        x_ref.requires_grad = True
        if transpose:
            y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
            y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref)
        else:
            y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref)
            y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)
        y_ref.backward(grad_out)
        x_ref_grad = x_ref.grad.clone()

        # forward outputs
        self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol))

        # gradients with respect to the input and the weights
        self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol))
        self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))

    @parameterized.expand(
        [
            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
        ],
        skip_on_empty=True,
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_device_instantiation(self, batch_size, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, transpose, tol, verbose):
        """
        Check that constructing the module inside a ``torch.device(self.device)`` context yields
        the same precomputed sparse psi tensors as constructing it on the default device.
        """
        nlat_in, _ = in_shape
        theta_cutoff = self._theta_cutoff(kernel_shape, nlat_in)

        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
        conv_args = (Conv, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, theta_cutoff)

        # reference instance, created outside any device context
        conv_host = self._make_conv(*conv_args)

        # same module, created with self.device as the default device for tensor factories
        with torch.device(self.device):
            conv_device = self._make_conv(*conv_args)

        # Index tensors must agree exactly. torch.equal also checks the shapes, whereas
        # allclose would broadcast and apply a relative tolerance that grows with the index.
        for name in ("psi_col_idx", "psi_row_idx", "psi_roff_idx", "psi_idx"):
            self.assertTrue(
                torch.equal(getattr(conv_host, name).cpu(), getattr(conv_device, name).cpu()),
                f"{name} differs between default-device and {self.device.type} instantiation",
            )

        # the filter values are floating point and only need to agree up to rounding
        psi_vals_host = conv_host.psi_vals.cpu()
        psi_vals_device = conv_device.psi_vals.cpu()
        self.assertEqual(psi_vals_host.shape, psi_vals_device.shape)
        self.assertTrue(torch.allclose(psi_vals_host, psi_vals_device))
385
if __name__ == "__main__":
    # Executing the file directly runs every test case defined in this module.
    unittest.main(module="__main__")