convolution.py 19.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
import abc
33
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
from warnings import warn
35
36
37
38
39
40
41
42

import math

import torch
import torch.nn as nn

from functools import partial

Boris Bonev's avatar
Boris Bonev committed
43
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
Boris Bonev's avatar
Boris Bonev committed
44
45
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
46
from torch_harmonics.filter_basis import get_filter_basis
47

48
# import custom C++/CUDA extensions if available
Boris Bonev's avatar
Boris Bonev committed
49
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError:
    # The compiled helpers are optional. Every use of them in this module is
    # guarded by `_cuda_extension_available`, so the placeholders below are
    # never called; they only make sure both names exist on the fallback path.
    preprocess_psi = None
    disco_cuda_extension = None
    _cuda_extension_available = False


59
def _normalize_convolution_tensor_s2(
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="none", merge_quadrature=False, eps=1e-9
):
    """
    Discretely normalizes the convolution tensor and optionally pre-applies quadrature weights.

    Supports the following three normalization modes:
    - "none": no basis normalization; values are only divided by (1 + eps).
    - "individual": for each output latitude and filter basis function, the quadrature-weighted
      2-norm sqrt(sum(|psi|^2 * q) / (4 pi)) is computed and the entries are divided by it.
    - "mean": the same norm is computed for each output latitude and then averaged over the
      output latitudes; each basis function is divided by this mean.

    Parameters
    ----------
    psi_idx : torch.Tensor
        Integer index tensor whose rows are (kernel index, output latitude index,
        flattened input index), where the flattened input index is ilat * in_shape[1] + ilon.
    psi_vals : torch.Tensor
        Values of the sparse tensor, one per column of psi_idx. Updated in place.
    in_shape, out_shape : sequence of int
        (nlat, nlon) of the input and output grids.
    kernel_size : int
        Number of filter basis functions.
    quad_weights : torch.Tensor
        Quadrature weights, indexed by the latitude that plays the input role.
    transpose_normalization : bool
        If True, the roles of the two latitude indices are swapped so that the same code
        normalizes the transposed convolution.
    basis_norm_mode : str
        One of "none", "individual" or "mean".
    merge_quadrature : bool
        If True, the quadrature weights are multiplied into the returned values.
    eps : float
        Added to the norm before dividing to avoid division by zero.

    Returns
    -------
    torch.Tensor
        The normalized values. Callers should use the return value: psi_vals is updated in
        place, but the final transpose correction produces a new tensor.

    Raises
    ------
    ValueError
        If basis_norm_mode is not one of the supported modes.
    """

    # validate up front so that a bad mode is reported even when there is nothing to loop over
    if basis_norm_mode not in ("none", "individual", "mean"):
        raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

    # indices for addressing kernels and the two latitude axes
    ikernel = psi_idx[0]
    row_lat = psi_idx[1]
    col_lat = psi_idx[2] // in_shape[1]

    if transpose_normalization:
        # here we are deliberately swapping input and output roles to handle
        # transpose normalization with the same code
        ilat_out, ilat_in = col_lat, row_lat
        nlat_out = in_shape[0]
    else:
        ilat_out, ilat_in = row_lat, col_lat
        nlat_out = out_shape[0]

    # quadrature weight belonging to every non-zero entry
    q = quad_weights[ilat_in].reshape(-1)

    if basis_norm_mode == "none":
        scale = 1.0 + eps
    else:
        # Accumulate the weighted squares of all entries that share a
        # (kernel, output latitude) pair in a single scatter-add instead of
        # scanning every entry once per pair.
        weighted_sq = psi_vals.abs().pow(2) * q
        bins = (ikernel * nlat_out + ilat_out).to(torch.long)
        sq_sums = torch.zeros(kernel_size * nlat_out, dtype=weighted_sq.dtype, device=weighted_sq.device)
        sq_sums.index_add_(0, bins, weighted_sq)

        # 2-norm, accounting for the fact that it is 4-pi normalized
        vnorm = torch.sqrt(sq_sums / 4 / torch.pi)

        if basis_norm_mode == "mean":
            # average over the output latitudes and broadcast back to every latitude
            vnorm = vnorm.reshape(kernel_size, nlat_out).mean(dim=1, keepdim=True).expand(kernel_size, nlat_out).reshape(-1)

        scale = vnorm[bins] + eps

    if merge_quadrature:
        normalized = psi_vals * q / scale
    else:
        normalized = psi_vals / scale

    # keep the in-place update of the input tensor
    psi_vals.copy_(normalized)

    if transpose_normalization and merge_quadrature:
        correction_factor = out_shape[1] / in_shape[1]
        psi_vals = psi_vals / correction_factor

    return psi_vals


def _precompute_convolution_tensor_s2(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    transpose_normalization=False,
    basis_norm_mode="none",
    merge_quadrature=False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape, out_shape : sequence of two ints
        (nlat, nlon) of the input and output grids.
    filter_basis : object
        Must provide a `kernel_size` attribute and a
        `compute_support_vals(theta, phi, r_cutoff=...)` method returning the indices and
        values of the basis functions on their support.
    grid_in, grid_out : str
        Grid types forwarded to `_precompute_latitudes`.
    theta_cutoff : float
        Cutoff passed to the filter basis as `r_cutoff`.
    transpose_normalization, basis_norm_mode, merge_quadrature
        Forwarded to `_normalize_convolution_tensor_s2`; `transpose_normalization` also
        selects whether the output or the input latitude weights are used for the quadrature.

    Returns
    -------
    out_idx : torch.Tensor
        int64 tensor with rows (kernel index, output latitude index, flattened input index).
    out_vals : torch.Tensor
        float32 tensor with the corresponding normalized values.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, _ = out_shape

    # precompute input and output grids
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    # compute quadrature weights that will be merged into the Psi tensor
    latitude_weights = wout if transpose_normalization else win
    quad_weights = 2.0 * torch.pi * torch.from_numpy(latitude_weights).float().reshape(-1, 1) / nlon_in

    # beta and gamma do not depend on the output latitude, so their
    # trigonometric functions are evaluated once outside of the loop
    beta = lons_in
    gamma = lats_in.reshape(-1, 1)
    cos_beta, sin_beta = torch.cos(beta), torch.sin(beta)
    cos_gamma, sin_gamma = torch.cos(gamma), torch.sin(gamma)
    y_rotated = sin_beta * sin_gamma

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        cos_alpha, sin_alpha = torch.cos(alpha), torch.sin(alpha)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        x = cos_alpha * cos_beta * sin_gamma + cos_gamma * sin_alpha
        y = y_rotated
        z = -cos_beta * sin_alpha * sin_gamma + cos_alpha * cos_gamma

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        # the clamp only has an effect if rounding pushes |z| marginally above 1, where arccos would return NaN
        theta = torch.arccos(torch.clamp(z, -1.0, 1.0))
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], torch.full_like(iidx[:, 0], t), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    return out_idx, out_vals
Boris Bonev's avatar
Boris Bonev committed
234
235
236
237


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous convolutions.

    Builds the filter basis, validates the channel grouping and creates the learnable
    weight of shape (out_channels, in_channels // groups, kernel_size) together with an
    optional bias of shape (out_channels,). Subclasses implement `forward`.

    Parameters
    ----------
    in_channels, out_channels : int
        Number of input and output channels; both must be divisible by `groups`.
    kernel_shape : int or list of int
        Shape of the filter basis, forwarded to `get_filter_basis`.
    basis_type : str
        Type of the filter basis, forwarded to `get_filter_basis`.
    groups : int
        Number of channel groups; has to be a positive integer.
    bias : bool
        Whether a bias parameter is created.

    Raises
    ------
    ValueError
        If `groups` is not positive or does not divide both channel counts.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        # Validate the grouping first: it is cheap, and a non-positive group
        # count would otherwise surface as a ZeroDivisionError or a math domain error.
        if groups < 1:
            raise ValueError("Error, the number of groups has to be a positive integer")
        if in_channels % groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")

        self.kernel_shape = kernel_shape

        # get the filter basis functions
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        # groups
        self.groups = groups
        self.groupsize = in_channels // self.groups

        # weight tensor, scaled by the number of terms that are summed per output channel
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @property
    def kernel_size(self):
        """Number of basis functions, as reported by the filter basis."""
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        """Apply the convolution to `x`; has to be implemented by subclasses."""
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolution on the 2-sphere, following [1].

    The sparse convolution tensor psi is precomputed at construction time and stored in
    non-persistent buffers (`psi_ker_idx`, `psi_row_idx`, `psi_col_idx`, `psi_vals` and,
    if the CUDA extension could be imported, `psi_roff_idx`).

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "none",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Set up the module and precompute the convolution tensor.

        `in_shape` and `out_shape` are (nlat, nlon) pairs; the number of input longitudes
        has to be a multiple of the number of output longitudes. If `theta_cutoff` is None
        it defaults to pi / (nlat_out - 1). Raises ValueError for a non-positive cutoff.
        """
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # the p-shift only works if the longitudes are divisible
        assert self.nlon_in % self.nlon_out == 0

        # heuristic cutoff based on the bandlimit of the input field and the overlap of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        psi_indices, psi_values = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO indices into kernel, row (output latitude) and column parts
        kernel_index = psi_indices[0, ...].contiguous()
        row_index = psi_indices[1, ...].contiguous()
        column_index = psi_indices[2, ...].contiguous()
        psi_values = psi_values.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for the GPU kernel
            row_offsets = preprocess_psi(self.kernel_size, out_shape[0], kernel_index, row_index, column_index, psi_values).contiguous()
            self.register_buffer("psi_roff_idx", row_offsets, persistent=False)

        # store all data structures as non-persistent buffers
        self.register_buffer("psi_ker_idx", kernel_index, persistent=False)
        self.register_buffer("psi_row_idx", row_index, persistent=False)
        self.register_buffer("psi_col_idx", column_index, persistent=False)
        self.register_buffer("psi_vals", psi_values, persistent=False)

    def extra_repr(self):
        r"""
        Summary of the module configuration used when printing the module.
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Kernel, row and column indices stacked into a single (3, nnz) tensor."""
        stacked = torch.stack((self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx), dim=0)
        return stacked.contiguous()

    def get_psi(self):
        """Assemble psi as a coalesced sparse COO tensor of shape (kernel_size, nlat_out, nlat_in * nlon_in)."""
        dense_shape = (self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
        return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=dense_shape).coalesce()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Contract the input with psi, then mix channels and kernel basis functions with the weights.

        Uses the CUDA kernel when the input lives on a CUDA device and the extension is
        available; otherwise falls back to the PyTorch implementation (with a warning if the
        input is on a CUDA device).
        """
        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.get_psi(), self.nlon_out)

        # the contraction result is five-dimensional: batch, channel, kernel, lat, lon
        batch, _, nkernel, nlat, nlon = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, nkernel, nlat, nlon)

        # weight multiplication, performed separately for every group
        grouped_weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        out = torch.einsum("bgckxy,gock->bgoxy", x, grouped_weight).contiguous()
        out = out.reshape(batch, -1, nlat, nlon)

        if self.bias is None:
            return out

        return out + self.bias.reshape(1, -1, 1, 1)


Boris Bonev's avatar
Boris Bonev committed
389
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolution on the 2-sphere, following [1].

    The sparse convolution tensor psi is precomputed at construction time with the roles of
    the input and output grids exchanged, and stored in non-persistent buffers
    (`psi_ker_idx`, `psi_row_idx`, `psi_col_idx`, `psi_vals` and, if the CUDA extension
    could be imported, `psi_roff_idx`).

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "none",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Set up the module and precompute the transposed convolution tensor.

        `in_shape` and `out_shape` are (nlat, nlon) pairs; the number of output longitudes
        has to be a multiple of the number of input longitudes. If `theta_cutoff` is None
        it defaults to pi / (nlat_in - 1). Raises ValueError for a non-positive cutoff.
        """
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # the p-shift only works if the longitudes are divisible
        assert self.nlon_out % self.nlon_in == 0

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # in_shape and out_shape (and the grids) are exchanged since we want the transpose convolution
        psi_indices, psi_values = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO indices into kernel, row and column parts
        kernel_index = psi_indices[0, ...].contiguous()
        row_index = psi_indices[1, ...].contiguous()
        column_index = psi_indices[2, ...].contiguous()
        psi_values = psi_values.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for the GPU kernel
            row_offsets = preprocess_psi(self.kernel_size, in_shape[0], kernel_index, row_index, column_index, psi_values).contiguous()
            self.register_buffer("psi_roff_idx", row_offsets, persistent=False)

        # store all data structures as non-persistent buffers
        self.register_buffer("psi_ker_idx", kernel_index, persistent=False)
        self.register_buffer("psi_row_idx", row_index, persistent=False)
        self.register_buffer("psi_col_idx", column_index, persistent=False)
        self.register_buffer("psi_vals", psi_values, persistent=False)

    def extra_repr(self):
        r"""
        Summary of the module configuration used when printing the module.
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Kernel, row and column indices stacked into a single (3, nnz) tensor."""
        stacked = torch.stack((self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx), dim=0)
        return stacked.contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Assemble psi as a coalesced sparse COO tensor.

        Without `semi_transposed` the tensor has shape (kernel_size, nlat_in, nlat_out * nlon_out).
        With `semi_transposed` the two latitude indices are exchanged and the longitude axis is
        flipped, giving shape (kernel_size, nlat_out, nlat_in * nlon_out).
        """
        if not semi_transposed:
            plain_shape = (self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
            return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=plain_shape).coalesce()

        # we do a semi-transposition to facilitate the computation
        stacked = self.psi_idx
        lat_out = stacked[2] // self.nlon_out
        # flip the axis of longitudes
        lon_out = self.nlon_out - 1 - (stacked[2] % self.nlon_out)
        lat_in = stacked[1]

        transposed_idx = torch.stack([stacked[0], lat_out, lat_in * self.nlon_out + lon_out], dim=0)
        transposed_shape = (self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)
        return torch.sparse_coo_tensor(transposed_idx, self.psi_vals, size=transposed_shape).coalesce()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Mix channels with the weights, then contract with the transposed psi.

        Expects a four-dimensional input (batch, channel, lat, lon). Uses the CUDA kernel when
        the data lives on a CUDA device and the extension is available; otherwise falls back to
        the PyTorch implementation (with a warning if the data is on a CUDA device).
        """
        batch, _, nlat, nlon = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, nlat, nlon)

        # weight multiplication, performed separately for every group
        grouped_weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        x = torch.einsum("bgcxy,gock->bgokxy", x, grouped_weight).contiguous()
        x = x.reshape(batch, -1, x.shape[-3], nlat, nlon)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.get_psi(semi_transposed=True), self.nlon_out)

        if self.bias is None:
            return out

        return out + self.bias.reshape(1, -1, 1, 1)