# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
from typing import List, Tuple, Union, Optional
from warnings import warn

import math

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis

# import custom C++/CUDA extensions if available
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False


def _normalize_convolution_tensor_s2(
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
):
    """Normalizes convolution tensor values based on specified normalization mode.
    
    This function applies different normalization strategies to the convolution tensor
    values based on the basis_norm_mode parameter. It can normalize each basis function
    individually per output latitude, normalize by the mean of these norms across all
    output latitudes, or normalize by the quadrature weight of the basis support. The
    function also optionally merges the quadrature weights into the tensor.

    Parameters
    -----------
    psi_idx: torch.Tensor
        Index tensor for the sparse convolution tensor.
    psi_vals: torch.Tensor
        Value tensor for the sparse convolution tensor.
    in_shape: Tuple[int]
        Tuple of (nlat_in, nlon_in) representing input grid dimensions.
    out_shape: Tuple[int]
        Tuple of (nlat_out, nlon_out) representing output grid dimensions.
    kernel_size: int
        Number of kernel basis functions.
    quad_weights: torch.Tensor
        Quadrature weights for numerical integration.
    transpose_normalization: bool
        If True, applies normalization in transpose direction.
    basis_norm_mode: str
        Normalization mode, one of ["none", "individual", "mean", "support"].
    merge_quadrature: bool
        If True, multiplies values by quadrature weights.
    eps: float
        Small epsilon value to prevent division by zero.

    Returns
    -------
    torch.Tensor
        Normalized convolution tensor values.

    Raises
    ------
    ValueError
        If basis_norm_mode is not one of the supported modes.
    """

    # exit here if no normalization is needed
    if basis_norm_mode == "none":
        return psi_vals

    # reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)

    # get the indices for addressing kernels, input and output latitudes
    ikernel = idx[0]

    if transpose_normalization:
        ilat_out = idx[2]
        ilat_in = idx[1]
        # here we are deliberately swapping input and output shapes to handle transpose normalization with the same code
        nlat_out = in_shape[0]
        correction_factor = out_shape[1] / in_shape[1]
    else:
        ilat_out = idx[1]
        ilat_in = idx[2]
        nlat_out = out_shape[0]

    # get the quadrature weights
    q = quad_weights[ilat_in].reshape(-1)

    # buffer to store intermediate values
    vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
    support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)

    # loop through dimensions to compute the norms
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            # find indices corresponding to the given output latitude and kernel basis function
            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

            # compute the 1-norm
            # vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
            vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])

            # compute the support
            support[ik, ilat] = torch.sum(q[iidx])


    # loop over values and renormalize
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

            if basis_norm_mode == "individual":
                val = vnorm[ik, ilat]
            elif basis_norm_mode == "mean":
                val = vnorm[ik, :].mean()
            elif basis_norm_mode == "support":
                val = support[ik, ilat]
            elif basis_norm_mode == "none":
                val = 1.0
            else:
                raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

            psi_vals[iidx] = psi_vals[iidx] / (val + eps)

            if merge_quadrature:
                psi_vals[iidx] = psi_vals[iidx] * q[iidx]


    if transpose_normalization and merge_quadrature:
        psi_vals = psi_vals / correction_factor

    return psi_vals


@lru_cache(typed=True, copy=True)
def _precompute_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
    filter_basis: FilterBasis,
    grid_in: Optional[str] = "equiangular",
    grid_out: Optional[str] = "equiangular",
    theta_cutoff: Optional[float] = 0.01 * math.pi,
    theta_eps: Optional[float] = 1e-3,
    transpose_normalization: Optional[bool] = False,
    basis_norm_mode: Optional[str] = "mean",
    merge_quadrature: Optional[bool] = False,
):
    """
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation is expressed in Euler angles using the YZY convention, which, applied to the north pole $(0,0,1)^T$, yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    -----------
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    filter_basis: FilterBasis
        Filter basis functions
    grid_in: str
        Input grid type
    grid_out: str
        Output grid type
    theta_cutoff: float
        Theta cutoff for the filter basis functions
    theta_eps: float
        Epsilon for the theta cutoff
    transpose_normalization: bool
        Whether to normalize the convolution tensor in the transpose direction
    basis_norm_mode: str
        Mode for basis normalization
    merge_quadrature: bool
        Whether to merge the quadrature weights into the convolution tensor

    Returns
    -------
    out_idx: torch.Tensor
        Index tensor of the convolution tensor
    out_vals: torch.Tensor
        Values tensor of the convolution tensor

    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # precompute input and output grids
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)

    # compute the phi differences
    # It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = _precompute_longitudes(nlon_in)

    # compute quadrature weights and merge them into the convolution tensor.
    # These quadrature weights integrate to 1 over the sphere.
    if transpose_normalization:
        quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

    # the effective theta cutoff is multiplied by a fudge factor to avoid aliasing with the grid width (especially near the poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    out_idx = []
    out_vals = []

    beta = lons_in
    gamma = lats_in.reshape(-1, 1)

    # compute trigs
    cbeta = torch.cos(beta)
    sbeta = torch.sin(beta)
    cgamma = torch.cos(gamma)
    sgamma = torch.sin(gamma)

    # compute row offsets
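    # out_roff stores CSR-style row offsets: out_roff[t + 1] - out_roff[t] is the number
    # of nonzero entries contributed by output latitude t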
    out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device)
    out_roff[0] = 0
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        x = torch.cos(alpha) * cbeta * sgamma + cgamma * torch.sin(alpha)
        y = sbeta * sgamma
        z = -cbeta * torch.sin(alpha) * sgamma + torch.cos(alpha) * cgamma

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure, compute row offsets
        out_idx.append(idx)
        out_vals.append(vals)
        out_roff[t + 1] = out_roff[t] + iidx.shape[0]

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1)
    out_vals = torch.cat(out_vals, dim=-1)

    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    out_idx = out_idx.contiguous()
    out_vals = out_vals.to(dtype=torch.float32).contiguous()

    return out_idx, out_vals, out_roff


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous convolutions

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    groups: Optional[int]
        Number of groups
    bias: Optional[bool]
        Whether to use bias

    Returns
    -------
    out: torch.Tensor
        Output tensor
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # get the filter basis functions
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        # groups
        self.groups = groups

        # weight tensor
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups
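        # initialize weights with scale 1/sqrt(fan-in), where fan-in = groupsize * kernel_size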
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @property
    def kernel_size(self):
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions

    Returns
    -------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
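
    Example
    -------
    A minimal usage sketch; the channel counts, grid sizes, and kernel shape below are
    chosen purely for illustration:

    >>> import torch
    >>> from torch_harmonics.convolution import DiscreteContinuousConvS2
    >>> conv = DiscreteContinuousConvS2(
    ...     in_channels=4, out_channels=8, in_shape=(32, 64), out_shape=(16, 32), kernel_shape=3
    ... )
    >>> x = torch.randn(2, 4, 32, 64)
    >>> conv(x).shape
    torch.Size([2, 8, 16, 32])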
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_in % self.nlon_out == 0

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        idx, vals, _ = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_out, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi as a COO matrix for the pure PyTorch fallback path
        self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)

    def extra_repr(self):
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """
        Get the convolution tensor index

        Returns
        -------
        psi_idx: torch.Tensor
            Convolution tensor index
        """
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self):
        """
        Get the convolution tensor

        Returns
        -------
        psi: torch.Tensor
            Convolution tensor
        """
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication
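        # einsum indices: b=batch, g=group, c=input channel within group, o=output channel
        # within group, k=kernel basis function, x/y=latitude/longitude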
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions
    
    Returns
    --------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
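
    Example
    -------
    A minimal usage sketch; the channel counts, grid sizes, and kernel shape below are
    chosen purely for illustration:

    >>> import torch
    >>> from torch_harmonics.convolution import DiscreteContinuousConvTransposeS2
    >>> convt = DiscreteContinuousConvTransposeS2(
    ...     in_channels=8, out_channels=4, in_shape=(16, 32), out_shape=(32, 64), kernel_shape=3
    ... )
    >>> x = torch.randn(2, 8, 16, 32)
    >>> convt(x).shape
    torch.Size([2, 4, 32, 64])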
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_out % self.nlon_in == 0

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want the transpose convolution
        idx, vals, _ = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store the semi-transposed psi for the pure PyTorch fallback path
        self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)

    def extra_repr(self):
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        if semi_transposed:
            # we do a semi-transposition to facilitate the computation
            tout = self.psi_idx[2] // self.nlon_out
            pout = self.psi_idx[2] % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_idx[1]
            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication
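        # einsum indices: b=batch, g=group, c=input channel within group, o=output channel
        # within group, k=kernel basis function, x/y=latitude/longitude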
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out