convolution.py 26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
import abc
33
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
from warnings import warn
35
36
37
38
39
40
41
42

import math

import torch
import torch.nn as nn

from functools import partial

Thorsten Kurth's avatar
Thorsten Kurth committed
43
44
from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
45
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
Boris Bonev's avatar
Boris Bonev committed
46
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
Thorsten Kurth's avatar
Thorsten Kurth committed
47
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis
48

49
# import custom C++/CUDA extensions if available
Boris Bonev's avatar
Boris Bonev committed
50
try:
51
    from disco_helpers import preprocess_psi
Boris Bonev's avatar
Boris Bonev committed
52
    import disco_cuda_extension
53

Boris Bonev's avatar
Boris Bonev committed
54
55
56
57
58
59
    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False


60
def _normalize_convolution_tensor_s2(
61
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
62
):
63
64
65
66
67
68
    """Normalizes convolution tensor values based on specified normalization mode.
    
    This function applies different normalization strategies to the convolution tensor
    values based on the basis_norm_mode parameter. It can normalize individual basis
    functions, compute mean normalization across all basis functions, or use support
    weights. The function also optionally merges quadrature weights into the tensor.
Andrea Paris's avatar
Andrea Paris committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

    Parameters
    -----------
    psi_idx: torch.Tensor
        Index tensor for the sparse convolution tensor.
    psi_vals: torch.Tensor
        Value tensor for the sparse convolution tensor.
    in_shape: Tuple[int]
        Tuple of (nlat_in, nlon_in) representing input grid dimensions.
    out_shape: Tuple[int]
        Tuple of (nlat_out, nlon_out) representing output grid dimensions.
    kernel_size: int
        Number of kernel basis functions.
    quad_weights: torch.Tensor
        Quadrature weights for numerical integration.
    transpose_normalization: bool
        If True, applies normalization in transpose direction.
    basis_norm_mode: str
        Normalization mode, one of ["none", "individual", "mean", "support"].
    merge_quadrature: bool
        If True, multiplies values by quadrature weights.
    eps: float
        Small epsilon value to prevent division by zero.

    Returns
    -------
    torch.Tensor
        Normalized convolution tensor values.

    Raises
    ------
    ValueError
        If basis_norm_mode is not one of the supported modes.
Boris Bonev's avatar
Boris Bonev committed
102
103
    """

Thorsten Kurth's avatar
Thorsten Kurth committed
104
105
106
107
    # exit here if no normalization is needed
    if basis_norm_mode == "none":
        return psi_vals

108
109
    # reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
Boris Bonev's avatar
Boris Bonev committed
110

111
112
    # getting indices for adressing kernels, input and output latitudes
    ikernel = idx[0]
Boris Bonev's avatar
Boris Bonev committed
113
114

    if transpose_normalization:
115
116
117
118
119
        ilat_out = idx[2]
        ilat_in = idx[1]
        # here we are deliberately swapping input and output shapes to handle transpose normalization with the same code
        nlat_out = in_shape[0]
        correction_factor = out_shape[1] / in_shape[1]
Boris Bonev's avatar
Boris Bonev committed
120
    else:
121
122
123
124
125
126
127
128
        ilat_out = idx[1]
        ilat_in = idx[2]
        nlat_out = out_shape[0]

    # get the quadrature weights
    q = quad_weights[ilat_in].reshape(-1)

    # buffer to store intermediate values
129
130
    vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
    support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
131
132
133
134
135
136
137
138

    # loop through dimensions to compute the norms
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            # find indices corresponding to the given output latitude and kernel basis function
            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

139
140
141
            # compute the 1-norm
            # vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
            vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])
142

143
144
145
146
            # compute the support
            support[ik, ilat] = torch.sum(q[iidx])


147
148
149
150
151
152
153
154
155
156
    # loop over values and renormalize
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

            if basis_norm_mode == "individual":
                val = vnorm[ik, ilat]
            elif basis_norm_mode == "mean":
                val = vnorm[ik, :].mean()
157
158
            elif basis_norm_mode == "support":
                val = support[ik, ilat]
159
160
161
162
163
            elif basis_norm_mode == "none":
                val = 1.0
            else:
                raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

164
165
            psi_vals[iidx] = psi_vals[iidx] / (val + eps)

166
            if merge_quadrature:
167
                psi_vals[iidx] = psi_vals[iidx] * q[iidx]
168
169
170
171


    if transpose_normalization and merge_quadrature:
        psi_vals = psi_vals / correction_factor
Boris Bonev's avatar
Boris Bonev committed
172
173
174
175

    return psi_vals


Thorsten Kurth's avatar
Thorsten Kurth committed
176
@lru_cache(typed=True, copy=True)
def _precompute_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
    filter_basis: FilterBasis,
    grid_in: Optional[str]="equiangular",
    grid_out: Optional[str]="equiangular",
    theta_cutoff: Optional[float]=0.01 * math.pi,
    theta_eps: Optional[float]=1e-3,
    transpose_normalization: Optional[bool]=False,
    basis_norm_mode: Optional[str]="mean",
    merge_quadrature: Optional[bool]=False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    -----------
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    filter_basis: FilterBasis
        Filter basis functions
    grid_in: str
        Input grid type
    grid_out: str
        Output grid type
    theta_cutoff: float
        Theta cutoff for the filter basis functions
    theta_eps: float
        Epsilon for the theta cutoff
    transpose_normalization: bool
        Whether to normalize the convolution tensor in the transpose direction
    basis_norm_mode: str
        Mode for basis normalization
    merge_quadrature: bool
        Whether to merge the quadrature weights into the convolution tensor

    Returns
    -------
    out_idx: torch.Tensor
        Index tensor of the convolution tensor
    out_vals: torch.Tensor
        Values tensor of the convolution tensor
    out_roff: torch.Tensor
        Row offsets into the COO datastructure, one entry per output latitude plus one
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    # number of basis functions in the filter basis
    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # precompute input and output grids (latitudes and their quadrature weights)
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = _precompute_longitudes(nlon_in)

    # compute quadrature weights and merge them into the convolution tensor.
    # These quadrature integrate to 1 over the sphere.
    if transpose_normalization:
        quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

    # effective theta cutoff if multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    # COO indices and values, accumulated one output latitude at a time
    out_idx = []
    out_vals = []

    # beta runs over input longitudes, gamma over input latitudes (as a column vector for broadcasting)
    beta = lons_in
    gamma = lats_in.reshape(-1, 1)

    # compute trigs
    cbeta = torch.cos(beta)
    sbeta = torch.sin(beta)
    cgamma = torch.cos(gamma)
    sgamma = torch.sin(gamma)

    # compute row offsets
    out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device)
    out_roff[0] = 0
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        x = torch.cos(alpha) * cbeta * sgamma + cgamma * torch.sin(alpha)
        y = sbeta * sgamma
        z = -cbeta * torch.sin(alpha) * sgamma + torch.cos(alpha) * cgamma

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure, compute row offsets
        out_idx.append(idx)
        out_vals.append(vals)
        out_roff[t + 1] = out_roff[t] + iidx.shape[0]

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1)
    out_vals = torch.cat(out_vals, dim=-1)

    # normalize the values and optionally fold the quadrature weights into the tensor
    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    # ensure contiguous storage and a fixed dtype for the downstream CUDA kernels
    out_idx = out_idx.contiguous()
    out_vals = out_vals.to(dtype=torch.float32).contiguous()

    return out_idx, out_vals, out_roff
Boris Bonev's avatar
Boris Bonev committed
331
332
333
334


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous convolutions

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    groups: Optional[int]
        Number of groups
    bias: Optional[bool]
        Whether to use bias

    Returns
    -------
    out: torch.Tensor
        Output tensor
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # the filter basis determines the number of trainable basis coefficients
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        # grouped convolution setup: channels must split evenly across groups
        self.groups = groups
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups

        # initialize the weight tensor with a variance-preserving scale
        init_scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(init_scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        # optional additive bias, one value per output channel
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    @property
    def kernel_size(self):
        # the effective kernel size is dictated by the filter basis
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions

    Returns
    -------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        # grid dimensions: (nlat, nlon) for input and output
        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_in % self.nlon_out == 0

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # precompute the sparse convolution tensor in COO form; quadrature weights
        # are merged directly into the values (merge_quadrature=True)
        idx, vals, _ = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_out, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers (recomputed on construction,
        # so they do not need to be stored in checkpoints)
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi as COO matrix just in case for torch input
        self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """
        Get the convolution tensor index

        Returns
        -------
        psi_idx: torch.Tensor
            Convolution tensor index
        """
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self):
        """
        Get the convolution tensor

        Returns
        -------
        psi: torch.Tensor
            Convolution tensor
        """
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the DISCO convolution to x.

        The sparse contraction is dispatched to the custom CUDA kernel when both
        the input lives on the GPU and the extension was importable; otherwise it
        falls back to the (slower) pure-PyTorch sparse contraction.
        """

        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)

        # extract shape; after the contraction x carries an extra kernel dimension K
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication: contract the input channels (c) and kernel basis (k)
        # per group against the learned weights
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


Boris Bonev's avatar
Boris Bonev committed
558
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions

    Returns
    --------
    out: torch.Tensor
        Output tensor

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        # grid dimensions: (nlat, nlon) for input and output
        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_out % self.nlon_in == 0

        # bandlimit heuristic (here based on the input resolution, since the
        # transpose convolution upsamples)
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want the transpose convolution
        idx, vals, _ = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers (recomputed on
        # construction, so they do not need to be stored in checkpoints)
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi just in case
        self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """
        Get the convolution tensor index

        Returns
        -------
        psi_idx: torch.Tensor
            Convolution tensor index
        """
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Get the convolution tensor

        Parameters
        -----------
        semi_transposed: bool
            Whether to semi-transpose the convolution tensor

        Returns
        -------
        psi: torch.Tensor
            Convolution tensor
        """

        if semi_transposed:
            # we do a semi-transposition to facilitate the computation
            tout = self.psi_idx[2] // self.nlon_out
            pout = self.psi_idx[2] % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_idx[1]
            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the transpose DISCO convolution to x.

        The weight contraction happens first (producing the kernel dimension),
        followed by the sparse transpose contraction, which is dispatched to the
        custom CUDA kernel when available and falls back to pure PyTorch otherwise.
        """

        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication: expand the input channels (c) per group into
        # output channels (o) and kernel basis coefficients (k)
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out