# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
from typing import List, Tuple, Union, Optional
from warnings import warn

import math

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis

# import custom C++/CUDA extensions if available
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False


def _normalize_convolution_tensor_s2(
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
):
    """Normalizes convolution tensor values based on specified normalization mode.
    
    This function applies different normalization strategies to the convolution tensor
    values based on the basis_norm_mode parameter. It can normalize individual basis
    functions, compute mean normalization across all basis functions, or use support
    weights. The function also optionally merges quadrature weights into the tensor.
    
    Args:
        psi_idx: Index tensor for the sparse convolution tensor.
        psi_vals: Value tensor for the sparse convolution tensor.
        in_shape: Tuple of (nlat_in, nlon_in) representing input grid dimensions.
        out_shape: Tuple of (nlat_out, nlon_out) representing output grid dimensions.
        kernel_size: Number of kernel basis functions.
        quad_weights: Quadrature weights for numerical integration.
        transpose_normalization: If True, applies normalization in transpose direction.
        basis_norm_mode: Normalization mode, one of ["none", "individual", "mean", "support"].
        merge_quadrature: If True, multiplies values by quadrature weights.
        eps: Small epsilon value to prevent division by zero.
    
    Returns:
        torch.Tensor: Normalized convolution tensor values.
    
    Raises:
        ValueError: If basis_norm_mode is not one of the supported modes.
    """

    # exit here if no normalization is needed
    if basis_norm_mode == "none":
        return psi_vals

    # reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)

    # getting indices for addressing kernels, input and output latitudes
    ikernel = idx[0]

    if transpose_normalization:
        ilat_out = idx[2]
        ilat_in = idx[1]
        # here we are deliberately swapping input and output shapes to handle transpose normalization with the same code
        nlat_out = in_shape[0]
        correction_factor = out_shape[1] / in_shape[1]
    else:
        ilat_out = idx[1]
        ilat_in = idx[2]
        nlat_out = out_shape[0]

    # get the quadrature weights
    q = quad_weights[ilat_in].reshape(-1)

    # buffer to store intermediate values
    vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
    support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)

    # loop through dimensions to compute the norms
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            # find indices corresponding to the given output latitude and kernel basis function
            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

            # compute the 1-norm
            # vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
            vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])

            # compute the support
            support[ik, ilat] = torch.sum(q[iidx])

    # loop over values and renormalize
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

            if basis_norm_mode == "individual":
                val = vnorm[ik, ilat]
            elif basis_norm_mode == "mean":
                val = vnorm[ik, :].mean()
            elif basis_norm_mode == "support":
                val = support[ik, ilat]
            elif basis_norm_mode == "none":
                val = 1.0
            else:
                raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

            psi_vals[iidx] = psi_vals[iidx] / (val + eps)

            if merge_quadrature:
                psi_vals[iidx] = psi_vals[iidx] * q[iidx]

    if transpose_normalization and merge_quadrature:
        psi_vals = psi_vals / correction_factor

    return psi_vals


@lru_cache(typed=True, copy=True)
def _precompute_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
    filter_basis: FilterBasis,
    grid_in: Optional[str] = "equiangular",
    grid_out: Optional[str] = "equiangular",
    theta_cutoff: Optional[float] = 0.01 * math.pi,
    theta_eps: Optional[float] = 1e-3,
    transpose_normalization: Optional[bool] = False,
    basis_norm_mode: Optional[str] = "mean",
    merge_quadrature: Optional[bool] = False,
):
    """
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation by the Euler angles uses the YZY convention, which, applied to the north pole $(0,0,1)^T$, yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    -----------
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    filter_basis: FilterBasis
        Filter basis functions
    grid_in: str
        Input grid type
    grid_out: str
        Output grid type
    theta_cutoff: float
        Theta cutoff for the filter basis functions
    theta_eps: float
        Epsilon for the theta cutoff
    transpose_normalization: bool
        Whether to normalize the convolution tensor in the transpose direction
    basis_norm_mode: str
        Mode for basis normalization
    merge_quadrature: bool
        Whether to merge the quadrature weights into the convolution tensor

    Returns
    -------
    out_idx: torch.Tensor
        Index tensor of the convolution tensor
    out_vals: torch.Tensor
        Values tensor of the convolution tensor

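    Example
    -------
    A minimal sketch of how the returned COO data can be assembled into a sparse matrix;
    the shapes, basis and cutoff below are illustrative:

    >>> basis = get_filter_basis(kernel_shape=3, basis_type="piecewise linear")
    >>> idx, vals, _ = _precompute_convolution_tensor_s2(
    ...     (12, 24), (6, 12), basis, theta_cutoff=torch.pi / 5, merge_quadrature=True
    ... )
    >>> psi = torch.sparse_coo_tensor(idx, vals, size=(basis.kernel_size, 6, 12 * 24)).coalesce()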
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # precompute input and output grids
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)

    # compute the phi differences
    # It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = _precompute_longitudes(nlon_in)

    # compute quadrature weights and merge them into the convolution tensor.
    # These quadrature weights integrate to 1 over the sphere.
    if transpose_normalization:
        quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

    # the effective theta cutoff is multiplied with a fudge factor to avoid aliasing with the grid width (especially near the poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    out_idx = []
    out_vals = []

    beta = lons_in
    gamma = lats_in.reshape(-1, 1)

    # compute trigs
    cbeta = torch.cos(beta)
    sbeta = torch.sin(beta)
    cgamma = torch.cos(gamma)
    sgamma = torch.sin(gamma)

    # compute row offsets
    out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device)
    out_roff[0] = 0
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and is therefore applied with a negative sign
        x = torch.cos(alpha) * cbeta * sgamma + cgamma * torch.sin(alpha)
        y = sbeta * sgamma
        z = -cbeta * torch.sin(alpha) * sgamma + torch.cos(alpha) * cgamma

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO data structure, compute row offsets
        out_idx.append(idx)
        out_vals.append(vals)
        out_roff[t + 1] = out_roff[t] + iidx.shape[0]

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1)
    out_vals = torch.cat(out_vals, dim=-1)

    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    out_idx = out_idx.contiguous()
    out_vals = out_vals.to(dtype=torch.float32).contiguous()

    return out_idx, out_vals, out_roff


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous convolutions

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    groups: Optional[int]
        Number of groups
    bias: Optional[bool]
        Whether to use bias

    Returns
    -------
    out: torch.Tensor
        Output tensor
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # get the filter basis functions
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        # groups
        self.groups = groups

        # weight tensor
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups
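        # fan-in style initialization: the variance is scaled by the per-group channel count and the number of kernel basis functions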
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @property
    def kernel_size(self):
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions

    Returns
    -------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
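    Example
    -------
    A minimal usage sketch; the channel counts, grid shapes and kernel shape below are
    illustrative (assuming the default equiangular grids and piecewise linear basis):

    >>> import torch
    >>> conv = DiscreteContinuousConvS2(4, 8, in_shape=(32, 64), out_shape=(16, 32), kernel_shape=3)
    >>> x = torch.randn(2, 4, 32, 64)
    >>> conv(x).shape
    torch.Size([2, 8, 16, 32])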
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_in % self.nlon_out == 0

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        idx, vals, _ = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_out, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all data structures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi as COO matrix just in case for torch input
        self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """
        Get the convolution tensor index

        Returns
        -------
        psi_idx: torch.Tensor
            Convolution tensor index
        """
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self):
        """
        Get the convolution tensor

        Returns
        -------
        psi: torch.Tensor
            Convolution tensor
        """
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)
        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication
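        # einsum indices: b=batch, g=group, c=input channel within group, k=kernel basis function,
        # o=output channel within group, x/y=latitude/longitude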
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)
        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions
    
    Returns
    --------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
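    Example
    -------
    A minimal usage sketch for upsampling on the sphere; the channel counts, grid shapes and
    kernel shape below are illustrative (assuming the default equiangular grids):

    >>> import torch
    >>> conv = DiscreteContinuousConvTransposeS2(4, 8, in_shape=(16, 32), out_shape=(32, 64), kernel_shape=3)
    >>> x = torch.randn(2, 4, 16, 32)
    >>> conv(x).shape
    torch.Size([2, 8, 32, 64])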
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_out % self.nlon_in == 0

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want the transpose convolution
        idx, vals, _ = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all data structures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi just in case
        self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Get the convolution tensor

        Parameters
        -----------
        semi_transposed: bool
            Whether to semi-transpose the convolution tensor

        Returns
        -------
        psi: torch.Tensor
            Convolution tensor
        """

        if semi_transposed:
            # we do a semi-transposition to facilitate the computation
            tout = self.psi_idx[2] // self.nlon_out
            pout = self.psi_idx[2] % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_idx[1]
            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication
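        # einsum indices: b=batch, g=group, c=input channel within group, o=output channel within group,
        # k=kernel basis function, x/y=latitude/longitude; the kernel dimension is kept for the transpose contraction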
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out