# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
from typing import List, Tuple, Union, Optional
from warnings import warn

import math

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis

# import custom C++/CUDA extensions if available
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError as err:
    # fall back to the pure-PyTorch code paths; this flag is checked at runtime
    # by the convolution modules below before dispatching to the CUDA kernels
    disco_cuda_extension = None
    _cuda_extension_available = False
def _normalize_convolution_tensor_s2(
61
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
62
):
63
    """Normalizes convolution tensor values based on specified normalization mode.
64

65
66
67
68
    This function applies different normalization strategies to the convolution tensor
    values based on the basis_norm_mode parameter. It can normalize individual basis
    functions, compute mean normalization across all basis functions, or use support
    weights. The function also optionally merges quadrature weights into the tensor.
Andrea Paris's avatar
Andrea Paris committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

    Parameters
    -----------
    psi_idx: torch.Tensor
        Index tensor for the sparse convolution tensor.
    psi_vals: torch.Tensor
        Value tensor for the sparse convolution tensor.
    in_shape: Tuple[int]
        Tuple of (nlat_in, nlon_in) representing input grid dimensions.
    out_shape: Tuple[int]
        Tuple of (nlat_out, nlon_out) representing output grid dimensions.
    kernel_size: int
        Number of kernel basis functions.
    quad_weights: torch.Tensor
        Quadrature weights for numerical integration.
    transpose_normalization: bool
        If True, applies normalization in transpose direction.
    basis_norm_mode: str
        Normalization mode, one of ["none", "individual", "mean", "support"].
    merge_quadrature: bool
        If True, multiplies values by quadrature weights.
    eps: float
        Small epsilon value to prevent division by zero.

    Returns
    -------
    torch.Tensor
        Normalized convolution tensor values.

    Raises
    ------
    ValueError
        If basis_norm_mode is not one of the supported modes.
Boris Bonev's avatar
Boris Bonev committed
102
103
    """

Thorsten Kurth's avatar
Thorsten Kurth committed
104
105
106
107
    # exit here if no normalization is needed
    if basis_norm_mode == "none":
        return psi_vals

108
109
    # reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
Boris Bonev's avatar
Boris Bonev committed
110

111
112
    # getting indices for adressing kernels, input and output latitudes
    ikernel = idx[0]
Boris Bonev's avatar
Boris Bonev committed
113
114

    if transpose_normalization:
115
116
117
118
119
        ilat_out = idx[2]
        ilat_in = idx[1]
        # here we are deliberately swapping input and output shapes to handle transpose normalization with the same code
        nlat_out = in_shape[0]
        correction_factor = out_shape[1] / in_shape[1]
Boris Bonev's avatar
Boris Bonev committed
120
    else:
121
122
123
124
125
126
127
128
        ilat_out = idx[1]
        ilat_in = idx[2]
        nlat_out = out_shape[0]

    # get the quadrature weights
    q = quad_weights[ilat_in].reshape(-1)

    # buffer to store intermediate values
129
130
    vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
    support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
131
132
133
134
135
136
137
138

    # loop through dimensions to compute the norms
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            # find indices corresponding to the given output latitude and kernel basis function
            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

139
140
141
            # compute the 1-norm
            # vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
            vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])
142

143
144
145
            # compute the support
            support[ik, ilat] = torch.sum(q[iidx])

146
147
148
149
150
151
152
153
154
155
    # loop over values and renormalize
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

            if basis_norm_mode == "individual":
                val = vnorm[ik, ilat]
            elif basis_norm_mode == "mean":
                val = vnorm[ik, :].mean()
156
157
            elif basis_norm_mode == "support":
                val = support[ik, ilat]
158
159
160
161
162
            elif basis_norm_mode == "none":
                val = 1.0
            else:
                raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

163
164
            psi_vals[iidx] = psi_vals[iidx] / (val + eps)

165
            if merge_quadrature:
166
                psi_vals[iidx] = psi_vals[iidx] * q[iidx]
167
168
169

    if transpose_normalization and merge_quadrature:
        psi_vals = psi_vals / correction_factor
Boris Bonev's avatar
Boris Bonev committed
170
171
172
173

    return psi_vals


@lru_cache(typed=True, copy=True)
def _precompute_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
    filter_basis: FilterBasis,
    grid_in: Optional[str] = "equiangular",
    grid_out: Optional[str] = "equiangular",
    theta_cutoff: Optional[float] = 0.01 * math.pi,
    theta_eps: Optional[float] = 1e-3,
    transpose_normalization: Optional[bool] = False,
    basis_norm_mode: Optional[str] = "mean",
    merge_quadrature: Optional[bool] = False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    -----------
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    filter_basis: FilterBasis
        Filter basis functions
    grid_in: str
        Input grid type
    grid_out: str
        Output grid type
    theta_cutoff: float
        Theta cutoff for the filter basis functions
    theta_eps: float
        Epsilon for the theta cutoff
    transpose_normalization: bool
        Whether to normalize the convolution tensor in the transpose direction
    basis_norm_mode: str
        Mode for basis normalization
    merge_quadrature: bool
        Whether to merge the quadrature weights into the convolution tensor

    Returns
    -------
    out_idx: torch.Tensor
        Index tensor of the convolution tensor (COO format, kernel/row/column rows)
    out_vals: torch.Tensor
        Values tensor of the convolution tensor
    out_roff: torch.Tensor
        Row offsets: out_roff[t + 1] - out_roff[t] is the number of nonzeros for output latitude t
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # precompute input and output grids
    # NOTE(review): _precompute_latitudes appears to return (latitudes, quadrature weights) — confirm in torch_harmonics.quadrature
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = _precompute_longitudes(nlon_in)

    # compute quadrature weights and merge them into the convolution tensor.
    # These quadrature integrate to 1 over the sphere.
    if transpose_normalization:
        quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

    # effective theta cutoff if multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    # COO index/value lists, filled one output latitude at a time
    out_idx = []
    out_vals = []

    beta = lons_in
    gamma = lats_in.reshape(-1, 1)

    # compute trigs
    cbeta = torch.cos(beta)
    sbeta = torch.sin(beta)
    cgamma = torch.cos(gamma)
    sgamma = torch.sin(gamma)

    # compute row offsets
    out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device)
    out_roff[0] = 0
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        x = torch.cos(alpha) * cbeta * sgamma + cgamma * torch.sin(alpha)
        y = sbeta * sgamma
        z = -cbeta * torch.sin(alpha) * sgamma + torch.cos(alpha) * cgamma

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure, compute row offsets
        out_idx.append(idx)
        out_vals.append(vals)
        out_roff[t + 1] = out_roff[t] + iidx.shape[0]

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1)
    out_vals = torch.cat(out_vals, dim=-1)

    # normalize the values, optionally folding the quadrature weights into them
    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    out_idx = out_idx.contiguous()
    out_vals = out_vals.to(dtype=torch.float32).contiguous()

    return out_idx, out_vals, out_roff

class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous convolutions.

    Owns the filter basis, the grouped weight tensor and the optional bias;
    concrete subclasses implement the actual `forward` contraction.

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    groups: Optional[int]
        Number of groups
    bias: Optional[bool]
        Whether to use bias
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # the filter basis determines how many trainable coefficients each filter has
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        # grouped convolution: both channel counts must split evenly across the groups
        self.groups = groups
        for nchans, direction in ((in_channels, "input"), (out_channels, "output")):
            if nchans % self.groups != 0:
                raise ValueError(f"Error, the number of {direction} channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups

        # fan-in style scaling of the randomly initialized filter coefficients
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        # optional additive bias, one entry per output channel
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    @property
    def kernel_size(self):
        # the number of basis functions is dictated by the filter basis
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError

class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions

    Returns
    -------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_in % self.nlon_out == 0

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # precompute the sparse convolution tensor in COO format; quadrature weights are merged into the values
        idx, vals, _ = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_out, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers so they follow .to()/.cuda() but stay out of state_dict
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi as COO matrix just in case for torch input
        self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)

    def extra_repr(self):
        # human-readable module summary shown by print(module)
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        # reassemble the separate kernel/row/col index buffers into a single COO index tensor
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        # sparse contraction with the precomputed filter tensor
        # NOTE(review): x appears to be (batch, channels, nlat_in, nlon_in), matching the transpose variant — confirm against _disco_s2_contraction_*
        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication: contract over in-group channels and kernel basis functions
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions

    Returns
    --------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_out % self.nlon_in == 0

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want the transpose convolution
        idx, vals, _ = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            # note: nlat_in is passed here (not nlat_out) because the tensor was built with swapped shapes above
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers so they follow .to()/.cuda() but stay out of state_dict
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi just in case
        self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)

    def extra_repr(self):
        # human-readable module summary shown by print(module)
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        # reassemble the separate kernel/row/col index buffers into a single COO index tensor
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication first (expands the kernel dimension), contraction afterwards
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        # transpose disco contraction with the precomputed filter tensor
        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out