convolution.py 22.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
import abc
33
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
from warnings import warn
35
36
37
38
39
40
41
42

import math

import torch
import torch.nn as nn

from functools import partial

Boris Bonev's avatar
Boris Bonev committed
43
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
Boris Bonev's avatar
Boris Bonev committed
44
45
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
46

47
# import custom C++/CUDA extensions if available
Boris Bonev's avatar
Boris Bonev committed
48
try:
49
    from disco_helpers import preprocess_psi
Boris Bonev's avatar
Boris Bonev committed
50
51
52
53
54
55
56
57
    import disco_cuda_extension
    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False


def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
    """
    Computes the index set that falls into the isotropic kernel's support and returns both indices and values.

    The kernel consists of ``nr // 2 + nr % 2`` radial basis functions. Each one is a
    piecewise-linear "hat" of half-width ``dr = 2 * r_cutoff / (nr + 1)`` centred on a
    collocation radius: ``k * dr`` when ``nr`` is odd, ``(k + 0.5) * dr`` when it is even.

    Parameters
    ----------
    r : torch.Tensor
        Two-dimensional tensor of radial distances (it is indexed with two indices below).
    phi : torch.Tensor
        Unused here; accepted so that the signature matches the anisotropic variant.
    nr : int
        Number of collocation points across the diameter of the kernel.
    r_cutoff : float
        Points with ``r > r_cutoff`` are never part of the support.

    Returns
    -------
    iidx : torch.Tensor
        Integer tensor with one row ``(kernel index, row of r, column of r)`` per support point.
    vals : torch.Tensor
        Value of the corresponding basis function at each support point.
    """

    kernel_size = (nr // 2) + nr % 2
    # shape (kernel_size, 1, 1) so that it broadcasts against the two-dimensional r
    ikernel = torch.arange(kernel_size, device=r.device).reshape(-1, 1, 1)
    dr = 2 * r_cutoff / (nr + 1)

    # radial collocation points: the odd case has a point at the centre, the even case does not
    if nr % 2 == 1:
        ir = ikernel * dr
    else:
        ir = (ikernel + 0.5) * dr

    # find the indices where the rotated position falls into the support of the kernel
    iidx = torch.argwhere(((r - ir).abs() <= dr) & (r <= r_cutoff))

    # linear hat function: 1 at the collocation point, 0 at distance dr
    vals = 1 - (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs() / dr

    return iidx, vals

Boris Bonev's avatar
Boris Bonev committed
78

Boris Bonev's avatar
Boris Bonev committed
79
def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
    """
    Computes the index set that falls into the anisotropic kernel's support and returns both indices and values. Handles the special case
    when there is an uneven number of collocation points across the diameter of the kernel.

    The kernel has ``(nr // 2) * nphi + nr % 2`` basis functions, each the product of a
    linear hat in the radial direction (half-width ``dr = 2 * r_cutoff / (nr + 1)``) and a
    linear hat in the azimuthal direction (half-width ``dphi = 2 * pi / nphi``, periodic).
    For odd ``nr`` the basis function with index 0 sits at ``r = 0`` and has no azimuthal
    dependence.

    Parameters
    ----------
    r : torch.Tensor
        Two-dimensional tensor of radial distances (it is indexed with two indices below).
    phi : torch.Tensor
        Azimuthal angles, indexed the same way as ``r``.
    nr : int
        Number of collocation points across the diameter of the kernel.
    nphi : int
        Number of collocation points in the azimuthal direction.
    r_cutoff : float
        Points with ``r > r_cutoff`` are never part of the support.

    Returns
    -------
    iidx : torch.Tensor
        Integer tensor with one row ``(kernel index, row of r, column of r)`` per support point.
    vals : torch.Tensor
        Value of the corresponding basis function at each support point.
    """

    kernel_size = (nr // 2) * nphi + nr % 2
    # shape (kernel_size, 1, 1) so that it broadcasts against the two-dimensional r and phi
    ikernel = torch.arange(kernel_size, device=r.device).reshape(-1, 1, 1)
    dr = 2 * r_cutoff / (nr + 1)
    dphi = 2.0 * math.pi / nphi

    # disambiguate even and uneven cases
    if nr % 2 == 1:
        # collocation points: index 0 is the centre point, the remaining indices are
        # laid out ring by ring with nphi points per ring
        ir = ((ikernel - 1) // nphi + 1) * dr
        iphi = ((ikernel - 1) % nphi) * dphi

        # find the support; the centre basis function (ikernel == 0) ignores the angle,
        # and the second angular term accounts for the 2 pi periodicity
        cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
        cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
        # find indices where conditions are met
        iidx = torch.argwhere(cond_r & cond_phi)
        # compute the distance to the collocation points
        dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
        dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
        # compute the value of the basis functions
        vals = 1 - dist_r / dr
        vals *= torch.where(
            (iidx[:, 0] > 0),
            (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi),
            1.0,
        )

    else:
        # collocation points: nphi points per ring, innermost ring at radius dr / 2
        ir = (ikernel // nphi + 0.5) * dr
        iphi = (ikernel % nphi) * dphi

        # in the even case, the inner basis functions overlap into areas with a negative radius;
        # such a point is represented by the mirrored coordinates (-r, phi + pi)
        rn = -r
        phin = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)
        # find the support
        cond_r = ((r - ir).abs() <= dr) & (r <= r_cutoff)
        cond_phi = ((phi - iphi).abs() <= dphi) | ((2 * math.pi - (phi - iphi).abs()) <= dphi)
        cond_rn = ((rn - ir).abs() <= dr) & (rn <= r_cutoff)
        cond_phin = ((phin - iphi).abs() <= dphi) | ((2 * math.pi - (phin - iphi).abs()) <= dphi)
        # find indices where conditions are met
        iidx = torch.argwhere((cond_r & cond_phi) | (cond_rn & cond_phin))
        dist_r = (r[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
        dist_phi = (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
        dist_rn = (rn[iidx[:, 1], iidx[:, 2]] - ir[iidx[:, 0], 0, 0]).abs()
        dist_phin = (phin[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()
        # compute the value of the basis functions; the boolean conditions act as masks so that
        # each of the two contributions is zero outside of its own support
        vals = cond_r[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_r / dr)
        vals *= cond_phi[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phi, (2 * math.pi - dist_phi)) / dphi)
        valsn = cond_rn[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - dist_rn / dr)
        valsn *= cond_phin[iidx[:, 0], iidx[:, 1], iidx[:, 2]] * (1 - torch.minimum(dist_phin, (2 * math.pi - dist_phin)) / dphi)
        vals += valsn

    return iidx, vals

140

141
def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
    """
    Discretely normalizes the convolution tensor.

    The entries of the sparse tensor are grouped by (kernel index, latitude) and every
    group is divided by its quadrature-weighted sum, so that each group integrates to
    (approximately) one. The grouping latitude is the output latitude by default and the
    input latitude when ``transpose_normalization`` is set.

    Parameters
    ----------
    psi_idx : torch.Tensor
        Integer tensor with three rows: kernel index, output latitude index and the
        flattened input index ``lat_in * nlon_in + lon_in``.
    psi_vals : torch.Tensor
        Values matching the columns of ``psi_idx``. Modified in place.
    in_shape, out_shape : sequence of two ints
        ``(nlat, nlon)`` of the input and output grids.
    kernel_shape : sequence of one or two ints
        Shape of the filter basis; determines the number of kernel indices.
    quad_weights : torch.Tensor
        Quadrature weights, indexed by input latitude (by output latitude when
        ``transpose_normalization`` is set).
    transpose_normalization : bool
        Normalize per input latitude instead of per output latitude.
    merge_quadrature : bool
        If set, the quadrature weights are multiplied into the returned values.
    eps : float
        Added to the normalization constant to avoid division by zero.

    Returns
    -------
    torch.Tensor
        ``psi_vals`` (the same tensor object), normalized.

    Raises
    ------
    ValueError
        If ``kernel_shape`` is neither one- nor two-dimensional.
    """

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    if len(kernel_shape) == 1:
        kernel_size = math.ceil(kernel_shape[0] / 2)
    elif len(kernel_shape) == 2:
        kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    # split the flattened input index implicitly into lat_in, lon_in; only lat_in is needed
    ker_idx = psi_idx[0]
    lat_out_idx = psi_idx[1]
    lat_in_idx = psi_idx[2] // nlon_in

    if transpose_normalization:
        # quadrature weights of the output latitude, one normalization constant per (kernel, lat_in);
        # the sum also runs over the input longitudes as these are not available for the output
        q = quad_weights[lat_out_idx].reshape(-1)
        n_groups = kernel_size * nlat_in
        group = ker_idx * nlat_in + lat_in_idx
    else:
        # quadrature weights of the input latitude, one normalization constant per (kernel, lat_out)
        q = quad_weights[lat_in_idx].reshape(-1)
        n_groups = kernel_size * nlat_out
        group = ker_idx * nlat_out + lat_out_idx

    # quadrature-weighted sum of every group, computed in a single scatter-add
    weighted = psi_vals * q
    vnorm = torch.zeros(n_groups, dtype=weighted.dtype, device=weighted.device)
    vnorm.index_add_(0, group, weighted)
    denom = vnorm[group] + eps

    if merge_quadrature:
        if transpose_normalization:
            # the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
            normalized = weighted * nlon_in / nlon_out / denom
        else:
            normalized = weighted / denom
    else:
        normalized = psi_vals / denom

    # keep the in-place semantics: the caller's tensor is updated and returned
    psi_vals.copy_(normalized)

    return psi_vals


def _precompute_convolution_tensor_s2(
    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape, out_shape : sequence of two ints
        ``(nlat, nlon)`` of the input and output grids.
    kernel_shape : sequence of one or two ints
        One entry selects the isotropic filter basis, two entries the anisotropic one.
    grid_in, grid_out : str
        Grid names forwarded to ``_precompute_latitudes``.
    theta_cutoff : float
        Radius of the filter support, passed to the basis functions as ``r_cutoff``.
    transpose_normalization, merge_quadrature : bool
        Forwarded to ``_normalize_convolution_tensor_s2``.

    Returns
    -------
    out_idx : torch.Tensor
        Long tensor with three rows: kernel index, output latitude index and the
        flattened input index ``lat_in * nlon_in + lon_in``.
    out_vals : torch.Tensor
        Normalized float32 values matching the columns of ``out_idx``.

    Raises
    ------
    ValueError
        If ``kernel_shape`` is neither one- nor two-dimensional.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    if len(kernel_shape) == 1:
        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
    elif len(kernel_shape) == 2:
        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # latitudes of both grids together with the second return value of _precompute_latitudes,
    # which is used as the latitudinal quadrature weights further below
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates; arctan2 is shifted by pi so that phi is non-negative
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = kernel_handle(theta, phi)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], torch.full_like(iidx[:, 0], t), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

    # the normalization indexes the weights by output latitude in the transposed case
    # and by input latitude otherwise
    if transpose_normalization:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
    else:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
    out_vals = _normalize_convolution_tensor_s2(
        out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
    )

    return out_idx, out_vals
Boris Bonev's avatar
Boris Bonev committed
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for DISCO convolutions

    Stores the kernel shape, derives the number of filter basis functions
    (``kernel_size``) from it and allocates the learnable ``weight`` of shape
    ``(out_channels, in_channels // groups, kernel_size)`` and the optional ``bias``.
    Subclasses have to implement ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        """
        Parameters
        ----------
        in_channels : int
            Number of input channels; has to be divisible by ``groups``.
        out_channels : int
            Number of output channels; has to be divisible by ``groups``.
        kernel_shape : int or sequence of one or two ints
            A single int (or one entry) selects an isotropic kernel, two entries an
            anisotropic kernel.
        groups : int
            Number of channel groups.
        bias : bool
            Whether to allocate a learnable bias, initialized to zero.

        Raises
        ------
        ValueError
            If ``kernel_shape`` is neither one- nor two-dimensional, or if the channel
            counts are not divisible by ``groups``.
        """
        super().__init__()

        if isinstance(kernel_shape, int):
            self.kernel_shape = [kernel_shape]
        else:
            self.kernel_shape = kernel_shape

        # number of basis functions of the filter
        if len(self.kernel_shape) == 1:
            self.kernel_size = math.ceil(self.kernel_shape[0] / 2)
            if self.kernel_shape[0] % 2 == 0:
                warn(
                    "Detected isotropic kernel with even number of collocation points in the radial direction. This feature is only supported out of consistency and may lead to unexpected behavior."
                )
        elif len(self.kernel_shape) == 2:
            self.kernel_size = (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
        else:
            # also rejects an empty kernel_shape, which would otherwise leave kernel_size undefined
            raise ValueError("kernel_shape should be either one- or two-dimensional.")

        # groups
        self.groups = groups

        # weight tensor
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups
        # scale the random initialization by the number of terms summed per output channel
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        """Apply the convolution to ``x``; has to be implemented by subclasses."""
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    The sparse convolution tensor psi is precomputed in the constructor and stored in COO
    form in the non-persistent buffers ``psi_ker_idx``, ``psi_row_idx``, ``psi_col_idx``
    and ``psi_vals``.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int, int],
        out_shape: Tuple[int, int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Parameters
        ----------
        in_channels, out_channels, kernel_shape, groups, bias
            Forwarded to ``DiscreteContinuousConv``.
        in_shape, out_shape : tuple of two ints
            ``(nlat, nlon)`` of the input and output grids.
        grid_in, grid_out : str
            Grid names forwarded to ``_precompute_convolution_tensor_s2``.
        theta_cutoff : float, optional
            Radius of the filter support. Defaults to ``pi / (nlat_out - 1)``.

        Raises
        ------
        ValueError
            If ``theta_cutoff`` is not positive.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        idx, vals = _precompute_convolution_tensor_s2(
            in_shape, out_shape, self.kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
        )

        # split the COO indices into kernel, row (output latitude) and column (flattened input) indices
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """COO indices of psi, stacked as rows (kernel index, row index, column index)."""
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self):
        """Return psi as a coalesced sparse COO tensor of size (kernel_size, nlat_out, nlat_in * nlon_in)."""
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Contract ``x`` with psi, then mix channels with the learned weights.

        The CUDA kernel is used when ``x`` is on a CUDA device and the extension could be
        imported; otherwise the PyTorch implementation is used.
        """

        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            psi = self.get_psi()
            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

        # extract shape; the contraction returns a five-dimensional tensor with the kernel axis K third
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication: sum over the channels within a group (c) and the kernel index (k)
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


Boris Bonev's avatar
Boris Bonev committed
430
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    The sparse convolution tensor psi is precomputed in the constructor with the roles of
    the input and output grids swapped, and stored in COO form in the non-persistent
    buffers ``psi_ker_idx``, ``psi_row_idx``, ``psi_col_idx`` and ``psi_vals``.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int, int],
        out_shape: Tuple[int, int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Parameters
        ----------
        in_channels, out_channels, kernel_shape, groups, bias
            Forwarded to ``DiscreteContinuousConv``.
        in_shape, out_shape : tuple of two ints
            ``(nlat, nlon)`` of the input and output grids.
        grid_in, grid_out : str
            Grid names forwarded (swapped) to ``_precompute_convolution_tensor_s2``.
        theta_cutoff : float, optional
            Radius of the filter support. Defaults to ``pi / (nlat_in - 1)``.

        Raises
        ------
        ValueError
            If ``theta_cutoff`` is not positive.
        """
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want transpose conv
        idx, vals = _precompute_convolution_tensor_s2(
            out_shape, in_shape, self.kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
        )

        # split the COO indices; because of the swap above, rows refer to latitudes of this
        # module's input grid and columns to flattened points of its output grid
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """COO indices of psi, stacked as rows (kernel index, row index, column index)."""
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Return psi as a coalesced sparse COO tensor.

        With ``semi_transposed=False`` the tensor has size
        ``(kernel_size, nlat_in, nlat_out * nlon_out)``. With ``semi_transposed=True`` the
        latitude axes are exchanged and the longitude axis is flipped, giving size
        ``(kernel_size, nlat_out, nlat_in * nlon_out)``.
        """
        if semi_transposed:
            # we do a semi-transposition to facilitate the computation
            tout = self.psi_col_idx // self.nlon_out
            pout = self.psi_col_idx % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_row_idx
            idx = torch.stack([self.psi_ker_idx, tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Mix channels with the learned weights, then contract with the transposed psi.

        The CUDA kernel is used when ``x`` is on a CUDA device and the extension could be
        imported; otherwise the PyTorch implementation is used.
        """
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication: sum over the channels within a group (c) and keep the kernel index (k)
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            psi = self.get_psi(semi_transposed=True)
            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out