convolution.py 22.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
import abc
33
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
from warnings import warn
35
36
37
38
39
40
41
42

import math

import torch
import torch.nn as nn

from functools import partial

Boris Bonev's avatar
Boris Bonev committed
43
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
Boris Bonev's avatar
Boris Bonev committed
44
45
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
46

Boris Bonev's avatar
Boris Bonev committed
47
48
# import custom C++/CUDA extensions
from disco_helpers import preprocess_psi
49

Boris Bonev's avatar
Boris Bonev committed
50
51
52
53
54
55
56
57
58
59
try:
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False


def _compute_support_vals_isotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, r_cutoff: float):
    """
    Computes the sparse support of an isotropic (purely radial) kernel basis.

    The basis consists of ``(nr // 2) + nr % 2`` piecewise-linear hat functions in ``r`` with
    spacing ``dr = 2 * r_cutoff / (nr + 1)``. For odd ``nr`` the first hat is centered at
    ``r = 0``; for even ``nr`` every center is shifted outwards by ``dr / 2``.

    Parameters
    ----------
    r : torch.Tensor
        Two-dimensional tensor of radial distances, indexed as ``r[i, j]``.
    phi : torch.Tensor
        Azimuthal angles. Not used by the isotropic basis; accepted so that this helper and
        the anisotropic one can be called through the same handle.
    nr : int
        Number of collocation points across the kernel diameter.
    r_cutoff : float
        Radius beyond which the kernel is zero.

    Returns
    -------
    iidx : torch.Tensor
        ``(nnz, 3)`` tensor of ``(basis index, i, j)`` triples that fall into the support.
    vals : torch.Tensor
        ``(nnz,)`` tensor holding the hat-function value of every triple.
    """
    n_basis = (nr // 2) + nr % 2
    k = torch.arange(n_basis).reshape(-1, 1, 1)
    dr = 2 * r_cutoff / (nr + 1)

    # radial centers of the hat functions, shaped (n_basis, 1, 1) so they broadcast against r
    centers = k * dr if nr % 2 == 1 else (k + 0.5) * dr

    # a point belongs to basis k when it is within one spacing of the center and inside the cutoff
    in_support = ((r - centers).abs() <= dr) & (r <= r_cutoff)
    iidx = torch.argwhere(in_support)

    kk, ii, jj = iidx[:, 0], iidx[:, 1], iidx[:, 2]
    vals = 1 - (r[ii, jj] - centers[kk, 0, 0]).abs() / dr

    return iidx, vals

Boris Bonev's avatar
Boris Bonev committed
80

Boris Bonev's avatar
Boris Bonev committed
81
def _compute_support_vals_anisotropic(r: torch.Tensor, phi: torch.Tensor, nr: int, nphi: int, r_cutoff: float):
    """
    Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.

    The basis is a tensor product of radial hat functions (spacing ``dr = 2 * r_cutoff / (nr + 1)``)
    and periodic angular hat functions (spacing ``dphi = 2 pi / nphi``), giving
    ``(nr // 2) * nphi + nr % 2`` basis functions. For odd ``nr`` basis 0 is an isotropic hat
    centered at ``r = 0`` and the others form rings of ``nphi`` functions. For even ``nr`` the
    innermost ring is centered at ``dr / 2`` and also overlaps the region of negative radius.
    That region is handled by mirroring each point through the origin (``r -> -r``,
    ``phi -> phi + pi``) and adding the mirrored contribution.

    Parameters
    ----------
    r, phi : torch.Tensor
        Two-dimensional tensors of radial distance and azimuth, indexed as ``[i, j]``.
        ``phi`` is compared periodically with period ``2 pi``.
    nr : int
        Number of collocation points across the kernel diameter.
    nphi : int
        Number of collocation points in the angular direction.
    r_cutoff : float
        Radius beyond which the kernel is zero.

    Returns
    -------
    iidx : torch.Tensor
        ``(nnz, 3)`` tensor of ``(basis index, i, j)`` triples in the support.
    vals : torch.Tensor
        ``(nnz,)`` basis-function values for those triples.
    """
    n_basis = (nr // 2) * nphi + nr % 2
    k = torch.arange(n_basis).reshape(-1, 1, 1)
    dr = 2 * r_cutoff / (nr + 1)
    dphi = 2.0 * math.pi / nphi

    def radial_mask(radius, centers):
        # within one radial spacing of the ring and not beyond the cutoff
        return ((radius - centers).abs() <= dr) & (radius <= r_cutoff)

    def angular_mask(angle, centers):
        # within one angular spacing of the center, measured periodically
        gap = (angle - centers).abs()
        return (gap <= dphi) | ((2 * math.pi - gap) <= dphi)

    def angular_hat(gap):
        # linear decay with the shorter of the two arcs between angle and center
        return 1 - torch.minimum(gap, (2 * math.pi - gap)) / dphi

    if nr % 2 == 1:
        # basis 0 sits at r = 0; bases 1.. are grouped into rings of nphi
        centers_r = ((k - 1) // nphi + 1) * dr
        centers_phi = ((k - 1) % nphi) * dphi

        # the center basis has no angular dependence, hence the (k == 0) short-circuit
        support = radial_mask(r, centers_r) & ((k == 0) | angular_mask(phi, centers_phi))
        iidx = torch.argwhere(support)
        kk, ii, jj = iidx[:, 0], iidx[:, 1], iidx[:, 2]

        dist_r = (r[ii, jj] - centers_r[kk, 0, 0]).abs()
        dist_phi = (phi[ii, jj] - centers_phi[kk, 0, 0]).abs()

        vals = 1 - dist_r / dr
        vals *= torch.where((kk > 0), angular_hat(dist_phi), 1.0)
        return iidx, vals

    # even nr: the inner basis functions overlap into the region of negative radius,
    # which is covered by evaluating every point a second time at its mirror image
    centers_r = (k // nphi + 0.5) * dr
    centers_phi = (k % nphi) * dphi
    r_mirror = -r
    phi_mirror = torch.where(phi + math.pi >= 2 * math.pi, phi - math.pi, phi + math.pi)

    cond_r = radial_mask(r, centers_r)
    cond_phi = angular_mask(phi, centers_phi)
    cond_rm = radial_mask(r_mirror, centers_r)
    cond_phim = angular_mask(phi_mirror, centers_phi)

    iidx = torch.argwhere((cond_r & cond_phi) | (cond_rm & cond_phim))
    kk, ii, jj = iidx[:, 0], iidx[:, 1], iidx[:, 2]

    dist_r = (r[ii, jj] - centers_r[kk, 0, 0]).abs()
    dist_phi = (phi[ii, jj] - centers_phi[kk, 0, 0]).abs()
    dist_rm = (r_mirror[ii, jj] - centers_r[kk, 0, 0]).abs()
    dist_phim = (phi_mirror[ii, jj] - centers_phi[kk, 0, 0]).abs()

    # each contribution is masked by its own support so that only valid halves are added
    vals = cond_r[kk, ii, jj] * (1 - dist_r / dr)
    vals *= cond_phi[kk, ii, jj] * angular_hat(dist_phi)
    vals_mirror = cond_rm[kk, ii, jj] * (1 - dist_rm / dr)
    vals_mirror *= cond_phim[kk, ii, jj] * angular_hat(dist_phim)
    vals += vals_mirror

    return iidx, vals

142

143
def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
    """
    Discretely normalizes the convolution tensor.

    Every basis function is normalized separately for every latitude so that its
    quadrature-weighted sum becomes one.

    * Regular case: the groups are ``(ikernel, lat_out)`` and the weights are taken at the
      input latitude.
    * ``transpose_normalization``: the groups are ``(ikernel, lat_in)`` and the weights are
      taken at the output latitude. The sum therefore also runs over the input longitudes.

    Parameters
    ----------
    psi_idx : torch.Tensor
        ``(3, nnz)`` COO indices ``(ikernel, lat_out, lat_in * nlon_in + lon_in)``.
        Kernel indices are assumed to be below the kernel size implied by ``kernel_shape``,
        as produced by the precompute routine.
    psi_vals : torch.Tensor
        ``(nnz,)`` values. Overwritten in place and also returned.
    in_shape, out_shape : tuple of int
        ``(nlat, nlon)`` of the input and output grids.
    kernel_shape : list of int
        One entry (isotropic) or two entries (anisotropic).
    quad_weights : torch.Tensor
        One quadrature weight per latitude. The caller in this module passes shape ``(nlat, 1)``.
    transpose_normalization : bool
        Selects the transposed grouping described above.
    merge_quadrature : bool
        If True, the quadrature weights are folded into the returned values.
    eps : float
        Added to every norm to avoid division by zero.

    Returns
    -------
    torch.Tensor
        The normalized ``psi_vals`` (the same tensor object that was passed in).

    Raises
    ------
    ValueError
        If ``kernel_shape`` is neither one- nor two-dimensional.
    """

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    if len(kernel_shape) == 1:
        kernel_size = math.ceil(kernel_shape[0] / 2)
    elif len(kernel_shape) == 2:
        kernel_size = (kernel_shape[0] // 2) * kernel_shape[1] + kernel_shape[0] % 2
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    # reshape the indices implicitly to be ikernel, lat_out, lat_in, lon_in
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // nlon_in, psi_idx[2] % nlon_in], dim=0)

    if transpose_normalization:
        # weights at the output latitude; normalize per (ikernel, lat_in), summing over lat_out and lon_in
        q = quad_weights[idx[1]].reshape(-1)
        ilat = idx[2]
        nlat = nlat_in
    else:
        # weights at the input latitude; normalize per (ikernel, lat_out)
        q = quad_weights[idx[2]].reshape(-1)
        ilat = idx[1]
        nlat = nlat_out

    # one flat group id per non-zero: a single scatter-add replaces a kernel_size x nlat
    # Python loop whose every iteration scanned all non-zeros
    group = idx[0] * nlat + ilat
    weighted = psi_vals * q
    vnorm = torch.zeros(kernel_size * nlat, dtype=weighted.dtype, device=weighted.device)
    vnorm.index_add_(0, group, weighted)
    denom = vnorm[group] + eps

    if merge_quadrature and transpose_normalization:
        # the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
        normalized = psi_vals * q * nlon_in / nlon_out / denom
    elif merge_quadrature:
        normalized = psi_vals * q / denom
    else:
        normalized = psi_vals / denom

    # preserve the in-place contract: psi_vals itself is updated and returned
    psi_vals.copy_(normalized)
    return psi_vals


def _precompute_convolution_tensor_s2(
    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi, transpose_normalization=False, merge_quadrature=False
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_size x nlat_out x (nlat_in * nlon_in), where kernel_size is the
    number of basis functions implied by kernel_shape. It is returned in sparse COO form.

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape, out_shape : tuple of int
        ``(nlat, nlon)`` of the input and output grids.
    kernel_shape : list of int
        ``[nr]`` for an isotropic kernel or ``[nr, nphi]`` for an anisotropic one.
    grid_in, grid_out : str
        Latitude grid types, forwarded to ``_precompute_latitudes``.
    theta_cutoff : float
        Support radius of the kernel (passed as ``r_cutoff`` to the basis helpers).
    transpose_normalization, merge_quadrature : bool
        Forwarded to ``_normalize_convolution_tensor_s2``. ``transpose_normalization`` also
        selects the output-grid quadrature weights instead of the input-grid ones.

    Returns
    -------
    out_idx : torch.Tensor
        ``(3, nnz)`` long tensor of ``(ikernel, lat_out, lat_in * nlon_in + lon_in)`` indices.
    out_vals : torch.Tensor
        ``(nnz,)`` float32 tensor of normalized values.

    Raises
    ------
    ValueError
        If ``in_shape`` / ``out_shape`` are not two-dimensional or ``kernel_shape`` is neither
        one- nor two-dimensional.
    """

    # explicit checks instead of assert, which is stripped under `python -O`
    if len(in_shape) != 2:
        raise ValueError("in_shape should be two-dimensional (nlat, nlon).")
    if len(out_shape) != 2:
        raise ValueError("out_shape should be two-dimensional (nlat, nlon).")

    if len(kernel_shape) == 1:
        kernel_handle = partial(_compute_support_vals_isotropic, nr=kernel_shape[0], r_cutoff=theta_cutoff)
    elif len(kernel_shape) == 2:
        kernel_handle = partial(_compute_support_vals_anisotropic, nr=kernel_shape[0], nphi=kernel_shape[1], r_cutoff=theta_cutoff)
    else:
        raise ValueError("kernel_shape should be either one- or two-dimensional.")

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates; arctan2 returns values in [-pi, pi], so phi lies in [0, 2pi].
        # The clamp is the identity on [-1, 1] and only catches rounding just outside that range,
        # which would otherwise make arccos return NaN and silently drop the point.
        theta = torch.arccos(z.clamp(-1.0, 1.0))
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = kernel_handle(theta, phi)

        # add the output latitude and reshape such that psi has dimensions kernel_size x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

    if transpose_normalization:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
    else:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
    out_vals = _normalize_convolution_tensor_s2(
        out_idx, out_vals, in_shape, out_shape, kernel_shape, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
    )

    return out_idx, out_vals
Boris Bonev's avatar
Boris Bonev committed
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for DISCO convolutions.

    Resolves ``kernel_shape`` into the number of basis functions (``kernel_size``), validates the
    channel grouping, and allocates the learnable weight of shape
    ``(out_channels, in_channels // groups, kernel_size)`` together with an optional zero-initialized
    bias of shape ``(out_channels,)``. Subclasses must implement ``forward``.

    Parameters
    ----------
    in_channels, out_channels : int
        Number of input and output channels. Both must be divisible by ``groups``.
    kernel_shape : int or list of int
        ``nr`` / ``[nr]`` for an isotropic kernel, ``[nr, nphi]`` for an anisotropic one.
    groups : int
        Number of channel groups.
    bias : bool
        Whether to add a learnable bias.

    Raises
    ------
    ValueError
        If ``kernel_shape`` is not one- or two-dimensional (including empty), or if the channel
        counts are not divisible by ``groups``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        if isinstance(kernel_shape, int):
            self.kernel_shape = [kernel_shape]
        else:
            self.kernel_shape = kernel_shape

        if len(self.kernel_shape) == 1:
            # isotropic kernel: ceil(nr / 2) radial basis functions
            self.kernel_size = math.ceil(self.kernel_shape[0] / 2)
            if self.kernel_shape[0] % 2 == 0:
                warn(
                    "Detected isotropic kernel with even number of collocation points in the radial direction. This feature is only supported out of consistency and may lead to unexpected behavior."
                )
        elif len(self.kernel_shape) == 2:
            # anisotropic kernel: nphi basis functions per ring, plus a center function when nr is odd
            self.kernel_size = (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
        else:
            # covers both empty and higher-dimensional kernel shapes
            raise ValueError("kernel_shape should be either one- or two-dimensional.")

        # groups
        self.groups = groups

        # weight tensor
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups
        # fan-in style scaling over the per-group input channels and basis functions
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    The sparse filter tensor psi is precomputed once in the constructor and kept in
    non-persistent buffers. The forward pass contracts the input with psi (custom CUDA kernel
    when available, PyTorch fallback otherwise), then applies the grouped learnable weights and
    the optional bias.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # default support radius: pi / (nlat_out - 1), a heuristic tied to the output grid resolution
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)
        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        psi_idx, psi_vals = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.kernel_shape,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            merge_quadrature=True,
        )

        # split the COO indices into (kernel, output latitude, flattened input position) rows
        # NOTE(review): preprocess_psi presumably sorts these tensors in place and returns
        # per-row offsets for the GPU kernel — confirm against the extension
        ker_idx, row_idx, col_idx = (psi_idx[row, ...].contiguous() for row in range(3))
        roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, psi_vals)

        # non-persistent buffers: moved with the module, but not stored in the state dict
        for buf_name, buf in (
            ("psi_roff_idx", roff_idx),
            ("psi_ker_idx", ker_idx),
            ("psi_row_idx", row_idx),
            ("psi_col_idx", col_idx),
            ("psi_vals", psi_vals),
        ):
            self.register_buffer(buf_name, buf, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Stacked ``(3, nnz)`` COO indices of psi."""
        return torch.stack((self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx), dim=0).contiguous()

    def get_psi(self):
        """Assemble psi as a coalesced sparse tensor of shape ``(kernel_size, nlat_out, nlat_in * nlon_in)``."""
        dense_shape = (self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
        return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=dense_shape).coalesce()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # contract with psi; the result is unpacked below as (batch, channels, kernel, lat, lon)
        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.get_psi(), self.nlon_out)

        batch, _, n_ker, n_lat, n_lon = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, n_ker, n_lat, n_lon)

        # grouped weights: (groups, out_per_group, groupsize, kernel_size)
        grouped_weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        out = torch.einsum("bgckxy,gock->bgoxy", x, grouped_weight).reshape(batch, -1, n_lat, n_lon)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


Boris Bonev's avatar
Boris Bonev committed
428
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    The filter tensor is precomputed with the roles of input and output grids swapped and with
    transposed normalization. In the forward pass the grouped weights are applied first, and the
    result is then scattered onto the output grid through psi (custom CUDA kernel when available,
    PyTorch fallback otherwise).

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # default support radius, tied to the input grid resolution here
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)
        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want transpose conv
        psi_idx, psi_vals = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.kernel_shape,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            merge_quadrature=True,
        )

        # rows are (kernel, input latitude, flattened output position) because of the swap above
        # NOTE(review): preprocess_psi presumably sorts these tensors in place and returns
        # per-row offsets for the GPU kernel — confirm against the extension
        ker_idx, row_idx, col_idx = (psi_idx[row, ...].contiguous() for row in range(3))
        roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, psi_vals)

        # non-persistent buffers: moved with the module, but not stored in the state dict
        for buf_name, buf in (
            ("psi_roff_idx", roff_idx),
            ("psi_ker_idx", ker_idx),
            ("psi_row_idx", row_idx),
            ("psi_col_idx", col_idx),
            ("psi_vals", psi_vals),
        ):
            self.register_buffer(buf_name, buf, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Stacked ``(3, nnz)`` COO indices of psi."""
        return torch.stack((self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx), dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Assemble psi as a coalesced sparse tensor.

        Without ``semi_transposed`` the shape is ``(kernel_size, nlat_in, nlat_out * nlon_out)``.
        With it, the output latitude becomes the second axis, the output longitude axis is
        flipped and merged with the input latitude, giving
        ``(kernel_size, nlat_out, nlat_in * nlon_out)``.
        """
        if not semi_transposed:
            return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

        # we do a semi-transposition to facilitate the computation
        ker, lat_in, flat_out = self.psi_idx
        lat_out = flat_out // self.nlon_out
        # flip the axis of longitudes
        lon_out = self.nlon_out - 1 - flat_out % self.nlon_out
        semi_idx = torch.stack([ker, lat_out, lat_in * self.nlon_out + lon_out], dim=0)
        return torch.sparse_coo_tensor(semi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is unpacked as (batch, channels, lat, lon)
        batch, _, n_lat, n_lon = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, n_lat, n_lon)

        # apply the grouped weights first: (groups, out_per_group, groupsize, kernel_size)
        grouped_weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        x = torch.einsum("bgcxy,gock->bgokxy", x, grouped_weight)
        x = x.reshape(batch, -1, x.shape[-3], n_lat, n_lon)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.get_psi(semi_transposed=True), self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out