test_convolution.py 16.7 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import unittest
33
from parameterized import parameterized, parameterized_class
Boris Bonev's avatar
Boris Bonev committed
34
35
36
37
38
from functools import partial
import math
import numpy as np
import torch
from torch.autograd import gradcheck
Boris Bonev's avatar
Boris Bonev committed
39
from torch_harmonics import quadrature, DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
Boris Bonev's avatar
Boris Bonev committed
40

Thorsten Kurth's avatar
Thorsten Kurth committed
41
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
Thorsten Kurth's avatar
Thorsten Kurth committed
42
from torch_harmonics.convolution import _precompute_convolution_tensor_s2
43

44
45
46
47
48
# one single-element tuple per device the test class below is parameterized over
_devices = [(torch.device("cpu"),)] + ([(torch.device("cuda"),)] if torch.cuda.is_available() else [])


49
def _normalize_convolution_tensor_dense(psi, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9):
    """Discretely normalize a dense convolution tensor.

    Intended to reproduce, on a dense tensor, the normalization the library applies to its
    compressed psi so that the two can be compared in the tests below.

    Args:
        psi: dense tensor indexed as ``(kernel, lat_out, lon_out, lat_in, lon_in)``.
        quad_weights: quadrature weights. They are broadcast along dim 1 (output latitudes) when
            ``transpose_normalization`` is set, and along dim 3 (input latitudes) otherwise.
        transpose_normalization: use the normalization of the transpose convolution.
        basis_norm_mode: ``"none"`` (norm is 1), ``"individual"`` (quadrature-weighted absolute sum
            of each basis function) or ``"mean"`` (that sum additionally averaged over latitudes).
        merge_quadrature: multiply the quadrature weights into the returned tensor. In the transpose
            case the weighted tensor is also divided by ``nlon_out / nlon_in``.
        eps: added to the norm before dividing.

    Returns:
        ``psi`` (optionally quadrature-weighted) divided by ``norm + eps``.

    Raises:
        ValueError: if ``basis_norm_mode`` is not one of the modes above.
    """
    _, _, nlon_out, _, nlon_in = psi.shape
    correction_factor = nlon_out / nlon_in

    # The transpose variant weights and sums over the output latitudes and averages over the
    # input latitudes. The regular variant does the opposite.
    if transpose_normalization:
        weight_shape, sum_dims, mean_dim = (1, -1, 1, 1, 1), (1, 4), 3
    else:
        weight_shape, sum_dims, mean_dim = (1, 1, 1, -1, 1), (3, 4), 1

    if basis_norm_mode == "none":
        psi_norm = 1.0
    elif basis_norm_mode in ("individual", "mean"):
        # The transpose normalization only looks at the first output longitude. This is not quite
        # symmetric because of the compressed way psi is stored in the main code; see the
        # normalization code in the actual implementation.
        magnitude = psi[:, :, :1].abs() if transpose_normalization else psi.abs()
        psi_norm = torch.sum(quad_weights.reshape(weight_shape) * magnitude, dim=sum_dims, keepdim=True)
        if basis_norm_mode == "mean":
            psi_norm = psi_norm.mean(dim=mean_dim, keepdim=True)
    else:
        raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

    if merge_quadrature:
        psi = quad_weights.reshape(weight_shape) * psi
        if transpose_normalization:
            psi = psi / correction_factor

    return psi / (psi_norm + eps)


87
88
89
def _precompute_convolution_tensor_dense(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    theta_eps=1e-3,
    transpose_normalization=False,
    basis_norm_mode="none",
    merge_quadrature=False,
):
    """Brute-force reference construction of the dense DISCO convolution tensor.

    For every output grid point ``(t, p)``, all input grid points are rotated so that the output
    point itself lands at ``z = 1`` (``theta = 0``). The filter basis is then evaluated on the
    rotated spherical coordinates ``(theta, phi)`` of the input points.

    Args:
        in_shape: ``(nlat_in, nlon_in)`` of the input grid.
        out_shape: ``(nlat_out, nlon_out)`` of the output grid.
        filter_basis: object exposing ``kernel_size`` and
            ``compute_support_vals(theta, phi, r_cutoff=...)``. The latter must return
            ``(indices, values)``, where the index columns are ``(kernel, lat_in, lon_in)``.
        grid_in: latitude grid type of the input grid.
        grid_out: latitude grid type of the output grid.
        theta_cutoff: support radius; passed, slightly enlarged, as ``r_cutoff``.
        theta_eps: relative enlargement applied to ``theta_cutoff``.
        transpose_normalization: selects the output (``True``) or input (``False``) quadrature
            weights, and is forwarded to ``_normalize_convolution_tensor_dense``.
        basis_norm_mode: forwarded to ``_normalize_convolution_tensor_dense``.
        merge_quadrature: forwarded to ``_normalize_convolution_tensor_dense``.

    Returns:
        ``float32`` tensor of shape ``(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)``.
    """
    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    lats_in, win = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
    lats_out, wout = quadrature._precompute_latitudes(nlat_out, grid=grid_out)

    # longitudes of both grids; only their differences enter the rotation below
    lons_in = _precompute_longitudes(nlon_in)
    lons_out = _precompute_longitudes(nlon_out)

    # the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid
    # width (especially near the poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    # quadrature weights that will be merged into the psi tensor
    if transpose_normalization:
        quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

    # dense accumulator, filled only where the rotated input point lies inside the filter support
    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64, device=lons_in.device)

    # the input-latitude factors do not depend on the output point, so compute them once
    gamma = lats_in.reshape(-1, 1)
    sin_gamma = torch.sin(gamma)
    cos_gamma = torch.cos(gamma)

    for t in range(nlat_out):
        # the output-latitude factors only depend on t
        alpha = -lats_out[t]
        sin_alpha = torch.sin(alpha)
        cos_alpha = torch.cos(alpha)

        for p in range(nlon_out):
            beta = lons_in - lons_out[p]
            sin_beta = torch.sin(beta)
            cos_beta = torch.cos(beta)

            # cartesian coordinates of the rotated input positions, shape (nlat_in, nlon_in)
            z = -cos_beta * sin_alpha * sin_gamma + cos_alpha * cos_gamma
            x = cos_alpha * cos_beta * sin_gamma + cos_gamma * sin_alpha
            y = sin_beta * sin_gamma

            # normalize instead of clipping to keep z inside the domain of arccos
            norm = torch.sqrt(x * x + y * y + z * z)
            x = x / norm
            y = y / norm
            z = z / norm

            # spherical coordinates of the rotated positions, with phi mapped to [0, 2*pi)
            theta = torch.arccos(z)
            phi = torch.arctan2(y, x)
            phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

            # find the indices where the rotated position falls into the support of the kernel
            iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)
            out[iidx[:, 0], t, p, iidx[:, 1], iidx[:, 2]] = vals

    # take care of normalization and cast to float
    out = _normalize_convolution_tensor_dense(
        out, quad_weights=quad_weights, transpose_normalization=transpose_normalization, basis_norm_mode=basis_norm_mode, merge_quadrature=merge_quadrature
    )
    out = out.to(dtype=torch.float32)

    return out


163
@parameterized_class(("device",), _devices)
class TestDiscreteContinuousConvolution(unittest.TestCase):
    """Test the discrete-continuous convolution module (CPU/CUDA if available)."""

    def setUp(self):
        torch.manual_seed(333)
        if self.device.type == "cuda":
            torch.cuda.manual_seed(333)

    @staticmethod
    def _theta_cutoff(kernel_shape, nlat_in):
        """Return the support radius ``(n + 1) * pi / (nlat_in - 1)`` used by all tests.

        ``n`` is ``kernel_shape`` itself for an integer kernel shape, else its first entry.
        """
        radial_size = kernel_shape if isinstance(kernel_shape, int) else kernel_shape[0]
        return (radial_size + 1) * torch.pi / float(nlat_in - 1)

    @staticmethod
    def _make_conv(transpose, in_channels, out_channels, in_shape, out_shape, kernel_shape, **kwargs):
        """Instantiate the (transpose) DISCO convolution under test with a single group and no bias.

        Any extra keyword arguments (basis, grids, cutoff, ...) are forwarded to the constructor.
        """
        conv_cls = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
        return conv_cls(in_channels, out_channels, in_shape, out_shape, kernel_shape, groups=1, bias=False, **kwargs)

    @parameterized.expand(
        [
            # regular convolution
            [8, 4, 2, (16, 32), (16, 32), 3, "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), 3, "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), (2, 2), "morlet", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (24, 48), (12, 24), 3, "zernike", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 24), (8, 8), 3, "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (18, 36), (6, 12), 7, "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), 5, "piecewise linear", "mean", "equiangular", "legendre-gauss", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), 5, "piecewise linear", "mean", "legendre-gauss", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), 5, "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
            # transpose convolution
            [8, 4, 2, (16, 32), (16, 32), 3, "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), 5, "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (3, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (4, 3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), (2, 2), "morlet", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (12, 24), (24, 48), 3, "zernike", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 8), (16, 24), 3, "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (6, 12), (18, 36), 7, "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), 5, "piecewise linear", "mean", "equiangular", "legendre-gauss", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), 5, "piecewise linear", "mean", "legendre-gauss", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), 5, "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
        ],
        skip_on_empty=True,
    )
    def test_forward_backward(
        self,
        batch_size,
        in_channels,
        out_channels,
        in_shape,
        out_shape,
        kernel_shape,
        basis_type,
        basis_norm_mode,
        grid_in,
        grid_out,
        transpose,
        tol,
        verbose,
    ):
        """Compare psi, the forward pass and all gradients against a dense einsum reference."""
        if verbose:
            print(f"Testing DISCO convolution on {in_shape[0]}x{in_shape[1]} {grid_in} grid to {out_shape[0]}x{out_shape[1]} {grid_out} grid on {self.device.type} device")

        theta_cutoff = self._theta_cutoff(kernel_shape, in_shape[0])

        conv = self._make_conv(
            transpose,
            in_channels,
            out_channels,
            in_shape,
            out_shape,
            kernel_shape,
            basis_type=basis_type,
            basis_norm_mode=basis_norm_mode,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
        ).to(self.device)

        # the transpose convolution is checked against the dense tensor of the swapped problem
        if transpose:
            dense_in_shape, dense_out_shape, dense_grid_in, dense_grid_out = out_shape, in_shape, grid_out, grid_in
            psi_size = (conv.kernel_size, conv.nlat_in, conv.nlat_out * conv.nlon_out)
        else:
            dense_in_shape, dense_out_shape, dense_grid_in, dense_grid_out = in_shape, out_shape, grid_in, grid_out
            psi_size = (conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)

        psi_dense = _precompute_convolution_tensor_dense(
            dense_in_shape,
            dense_out_shape,
            conv.filter_basis,
            grid_in=dense_grid_in,
            grid_out=dense_grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=transpose,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        ).to(self.device)

        # the sparse psi only holds the first output longitude of the dense tensor (index 0 of dim 2)
        psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=psi_size).to_dense()
        self.assertTrue(torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, *psi_size[1:])), msg="sparse psi does not match the dense reference")

        # detached copy of the weight for the reference computation
        w_ref = conv.weight.detach().clone().requires_grad_(True)

        # FWD and BWD pass of the module
        x = torch.randn(batch_size, in_channels, *in_shape, device=self.device, requires_grad=True)
        y = conv(x)
        grad_output = torch.randn_like(y)
        y.backward(grad_output)
        x_grad = x.grad.clone()

        # reference computation with the dense tensor
        x_ref = x.detach().clone().requires_grad_(True)
        if transpose:
            y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
            y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref)
        else:
            y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref)
            y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)
        y_ref.backward(grad_output)
        x_ref_grad = x_ref.grad.clone()

        # compare outputs and gradients
        self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol), msg="forward output mismatch")
        self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol), msg="input gradient mismatch")
        self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol), msg="weight gradient mismatch")

    @parameterized.expand(
        [
            [8, 4, 2, (16, 32), (16, 32), 3, "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
            [8, 4, 2, (16, 32), (8, 16), 5, "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
            [8, 4, 2, (16, 32), (16, 32), 3, "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
            [8, 4, 2, (8, 16), (16, 32), 5, "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
        ],
        skip_on_empty=True,
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_device_instantiation(self, batch_size, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, transpose, tol, verbose):
        """Modules built with and without a default-device context must precompute identical psi tensors."""
        conv_kwargs = dict(
            basis_type=basis_type,
            basis_norm_mode=basis_norm_mode,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=self._theta_cutoff(kernel_shape, in_shape[0]),
        )

        # init without a device context
        conv_host = self._make_conv(transpose, in_channels, out_channels, in_shape, out_shape, kernel_shape, **conv_kwargs)

        # init with the test device as the default device
        with torch.device(self.device):
            conv_device = self._make_conv(transpose, in_channels, out_channels, in_shape, out_shape, kernel_shape, **conv_kwargs)

        # the precomputed tensors must not depend on the default device active at construction;
        # integer index tensors are compared exactly (allclose would broadcast and allow rtol slack)
        for name in ("psi_col_idx", "psi_row_idx", "psi_roff_idx", "psi_idx"):
            host_idx = getattr(conv_host, name).cpu()
            device_idx = getattr(conv_device, name).cpu()
            self.assertTrue(torch.equal(host_idx, device_idx), msg=f"{name} differs between host and {self.device.type} construction")

        self.assertEqual(conv_host.psi_vals.shape, conv_device.psi_vals.shape)
        self.assertTrue(torch.allclose(conv_host.psi_vals.cpu(), conv_device.psi_vals.cpu()), msg="psi_vals differs between host and device construction")

Boris Bonev's avatar
Boris Bonev committed
379

Boris Bonev's avatar
Boris Bonev committed
380
381
if __name__ == "__main__":
    # run every test case defined in this module when it is executed as a script
    unittest.main(module="__main__")