convolution.py 26 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
import abc
33
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
from warnings import warn
35
36
37
38
39
40
41
42

import math

import torch
import torch.nn as nn

from functools import partial

Thorsten Kurth's avatar
Thorsten Kurth committed
43
44
from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
45
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
Boris Bonev's avatar
Boris Bonev committed
46
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
Thorsten Kurth's avatar
Thorsten Kurth committed
47
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis
48

49
# import custom C++/CUDA extensions if available
try:
    # preprocess_psi builds the row-offset index structure consumed by the CUDA kernels
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError as err:
    # extensions not built/installed: fall back to the (slower) pure PyTorch code path
    disco_cuda_extension = None
    _cuda_extension_available = False


60
def _normalize_convolution_tensor_s2(
61
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
62
):
Boris Bonev's avatar
Boris Bonev committed
63
    """
64
65
66
67
    Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
    - "none": No normalization is applied.
    - "individual": for each output latitude and filter basis function the filter is numerically integrated over the sphere and normalized so that it yields 1.
    - "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.
apaaris's avatar
apaaris committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

    Parameters
    -----------
    psi_idx: torch.Tensor
        Index tensor of the convolution tensor
    psi_vals: torch.Tensor
        Values tensor of the convolution tensor
    in_shape: Tuple[int]
        Input shape of the convolution tensor
    out_shape: Tuple[int]
        Output shape of the convolution tensor
    kernel_size: int
        Size of the kernel
    quad_weights: torch.Tensor
        Quadrature weights
    transpose_normalization: bool
        Whether to normalize the convolution tensor in the transpose direction
    basis_norm_mode: str
        Mode for basis normalization
    merge_quadrature: bool
        Whether to merge the quadrature weights into the convolution tensor
    eps: float
        Small epsilon to avoid division by zero

    Returns
    -------
    psi_vals: torch.Tensor
        Normalized convolution tensor
Boris Bonev's avatar
Boris Bonev committed
96
97
    """

Thorsten Kurth's avatar
Thorsten Kurth committed
98
99
100
101
    # exit here if no normalization is needed
    if basis_norm_mode == "none":
        return psi_vals

102
103
    # reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
Boris Bonev's avatar
Boris Bonev committed
104

105
106
    # getting indices for adressing kernels, input and output latitudes
    ikernel = idx[0]
Boris Bonev's avatar
Boris Bonev committed
107
108

    if transpose_normalization:
109
110
111
112
113
        ilat_out = idx[2]
        ilat_in = idx[1]
        # here we are deliberately swapping input and output shapes to handle transpose normalization with the same code
        nlat_out = in_shape[0]
        correction_factor = out_shape[1] / in_shape[1]
Boris Bonev's avatar
Boris Bonev committed
114
    else:
115
116
117
118
119
120
121
122
        ilat_out = idx[1]
        ilat_in = idx[2]
        nlat_out = out_shape[0]

    # get the quadrature weights
    q = quad_weights[ilat_in].reshape(-1)

    # buffer to store intermediate values
123
124
    vnorm = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
    support = torch.zeros(kernel_size, nlat_out, device=psi_vals.device)
125
126
127
128
129
130
131
132

    # loop through dimensions to compute the norms
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            # find indices corresponding to the given output latitude and kernel basis function
            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

133
134
135
            # compute the 1-norm
            # vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
            vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])
136

137
138
139
140
            # compute the support
            support[ik, ilat] = torch.sum(q[iidx])


141
142
143
144
145
146
147
148
149
150
    # loop over values and renormalize
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

            if basis_norm_mode == "individual":
                val = vnorm[ik, ilat]
            elif basis_norm_mode == "mean":
                val = vnorm[ik, :].mean()
151
152
            elif basis_norm_mode == "support":
                val = support[ik, ilat]
153
154
155
156
157
            elif basis_norm_mode == "none":
                val = 1.0
            else:
                raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

158
159
            psi_vals[iidx] = psi_vals[iidx] / (val + eps)

160
            if merge_quadrature:
161
                psi_vals[iidx] = psi_vals[iidx] * q[iidx]
162
163
164
165


    if transpose_normalization and merge_quadrature:
        psi_vals = psi_vals / correction_factor
Boris Bonev's avatar
Boris Bonev committed
166
167
168
169

    return psi_vals


Thorsten Kurth's avatar
Thorsten Kurth committed
170
@lru_cache(typed=True, copy=True)
def _precompute_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
    filter_basis: FilterBasis,
    grid_in: Optional[str]="equiangular",
    grid_out: Optional[str]="equiangular",
    theta_cutoff: Optional[float]=0.01 * math.pi,
    theta_eps: Optional[float]=1e-3,
    transpose_normalization: Optional[bool]=False,
    basis_norm_mode: Optional[str]="mean",
    merge_quadrature: Optional[bool]=False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in) and is returned in sparse COO layout.

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    -----------
    in_shape: Tuple[int]
        Input shape (nlat_in, nlon_in) of the convolution tensor
    out_shape: Tuple[int]
        Output shape (nlat_out, nlon_out) of the convolution tensor
    filter_basis: FilterBasis
        Filter basis functions
    grid_in: str
        Input grid type
    grid_out: str
        Output grid type
    theta_cutoff: float
        Theta cutoff for the filter basis functions
    theta_eps: float
        Relative fudge factor applied to theta_cutoff to avoid aliasing with the grid width
    transpose_normalization: bool
        Whether to normalize the convolution tensor in the transpose direction
    basis_norm_mode: str
        Mode for basis normalization
    merge_quadrature: bool
        Whether to merge the quadrature weights into the convolution tensor

    Returns
    -------
    out_idx: torch.Tensor
        Index tensor of the convolution tensor
    out_vals: torch.Tensor
        Values tensor of the convolution tensor
    out_roff: torch.Tensor
        Row offsets: cumulative count of nonzero entries per output latitude, length nlat_out + 1
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # precompute input and output grids
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = _precompute_longitudes(nlon_in)

    # compute quadrature weights and merge them into the convolution tensor.
    # These quadrature weights integrate to 1 over the sphere.
    if transpose_normalization:
        quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

    # effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    # COO lists, filled one output latitude at a time
    out_idx = []
    out_vals = []

    beta = lons_in
    gamma = lats_in.reshape(-1, 1)

    # compute trigs
    cbeta = torch.cos(beta)
    sbeta = torch.sin(beta)
    cgamma = torch.cos(gamma)
    sgamma = torch.sin(gamma)

    # compute row offsets
    out_roff = torch.zeros(nlat_out + 1, dtype=torch.int64, device=lons_in.device)
    out_roff[0] = 0
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        x = torch.cos(alpha) * cbeta * sgamma + cgamma * torch.sin(alpha)
        y = sbeta * sgamma
        z = -cbeta * torch.sin(alpha) * sgamma + torch.cos(alpha) * cgamma

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure, compute row offsets
        out_idx.append(idx)
        out_vals.append(vals)
        out_roff[t + 1] = out_roff[t] + iidx.shape[0]

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1)
    out_vals = torch.cat(out_vals, dim=-1)

    # normalize the values and optionally fold in the quadrature weights
    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    out_idx = out_idx.contiguous()
    out_vals = out_vals.to(dtype=torch.float32).contiguous()

    return out_idx, out_vals, out_roff
Boris Bonev's avatar
Boris Bonev committed
325
326
327
328


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous (DISCO) convolutions.

    Holds the machinery shared by the concrete spherical convolution modules:
    filter basis construction, grouped-channel bookkeeping, and the learnable
    weight and bias parameters. Subclasses implement the actual contraction
    in ``forward``.

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    groups: Optional[int]
        Number of groups
    bias: Optional[bool]
        Whether to use bias
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # construct the filter basis from the requested kernel shape and basis type
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        self.groups = groups

        # grouped convolution: both channel counts must split evenly across groups
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups

        # scale the random init so the variance shrinks with fan-in (groupsize * kernel_size)
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        # optional additive bias, one entry per output channel
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    @property
    def kernel_size(self):
        # number of basis functions, as reported by the filter basis
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape (nlat_in, nlon_in) of the signal on the sphere
    out_shape: Tuple[int]
        Output shape (nlat_out, nlon_out) of the signal on the sphere
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions; defaults to pi / (nlat_out - 1)

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_in % self.nlon_out == 0

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # precompute the sparse convolution tensor (cached + quadrature merged)
        idx, vals, _ = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_out, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers so they follow .to()/.cuda() but are not checkpointed
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi as COO matrix just in case for torch input
        self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """
        Get the convolution tensor index

        Returns
        -------
        psi_idx: torch.Tensor
            Convolution tensor index, stacked as [kernel, row, col]
        """
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self):
        """
        Get the convolution tensor

        Returns
        -------
        psi: torch.Tensor
            Convolution tensor as a sparse COO tensor of shape (kernel_size, nlat_out, nlat_in * nlon_in)
        """
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: sparse contraction with psi, then channel mixing with the learned weights.

        Parameters
        -----------
        x: torch.Tensor
            Input tensor; assumed layout (batch, channel, nlat_in, nlon_in) -- TODO confirm against contraction helpers

        Returns
        -------
        out: torch.Tensor
            Output tensor of shape (batch, out_channels, nlat_out, nlon_out)
        """

        # use the fast CUDA kernels when both the input and the extension are available
        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)

        # extract shape; contraction produced one feature map per kernel basis function (K)
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication: contract channel (c) and kernel (k) dimensions per group
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


Boris Bonev's avatar
Boris Bonev committed
552
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].

    Parameters
    -----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Input shape (nlat_in, nlon_in) of the signal on the sphere
    out_shape: Tuple[int]
        Output shape (nlat_out, nlon_out) of the signal on the sphere
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the kernel
    basis_type: Optional[str]
        Type of the basis functions
    basis_norm_mode: Optional[str]
        Mode for basis normalization
    groups: Optional[int]
        Number of groups
    grid_in: Optional[str]
        Input grid type
    grid_out: Optional[str]
        Output grid type
    bias: Optional[bool]
        Whether to use bias
    theta_cutoff: Optional[float]
        Theta cutoff for the filter basis functions; defaults to pi / (nlat_in - 1)

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_out % self.nlon_in == 0

        # bandlimit heuristic: cutoff derived from the (coarser) input resolution
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want the transpose convolution
        idx, vals, _ = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_in, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers so they follow .to()/.cuda() but are not checkpointed
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi (semi-transposed form) just in case for the torch fallback path
        self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """
        Get the convolution tensor index

        Returns
        -------
        psi_idx: torch.Tensor
            Convolution tensor index, stacked as [kernel, row, col]
        """
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Get the convolution tensor

        Parameters
        -----------
        semi_transposed: bool
            Whether to semi-transpose the convolution tensor

        Returns
        -------
        psi: torch.Tensor
            Convolution tensor as a sparse COO tensor
        """

        if semi_transposed:
            # we do a semi-transposition to facilitate the computation
            tout = self.psi_idx[2] // self.nlon_out
            pout = self.psi_idx[2] % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_idx[1]
            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass

        Parameters
        -----------
        x: torch.Tensor
            Input tensor; assumed layout (batch, channel, nlat_in, nlon_in) -- TODO confirm against contraction helpers

        Returns
        -------
        out: torch.Tensor
            Output tensor
        """
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication first (channel mixing), producing one map per kernel basis function (k)
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        # then apply the transpose sparse contraction to scatter onto the output grid
        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out