"torchvision/vscode:/vscode.git/clone" did not exist on "22385bc66c0a98fa78ce94e858a9d79aeb9a885a"
convolution.py 20.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
import abc
33
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
from warnings import warn
35
36
37
38
39
40
41
42

import math

import torch
import torch.nn as nn

from functools import partial

Boris Bonev's avatar
Boris Bonev committed
43
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
Boris Bonev's avatar
Boris Bonev committed
44
45
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
46
from torch_harmonics.filter_basis import get_filter_basis
47

48
# import custom C++/CUDA extensions if available; they are optional and the
# module falls back to the pure PyTorch contraction when they are missing
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError:
    # define every name the try-branch would have bound, so later references
    # fail predictably instead of raising NameError
    preprocess_psi = None
    disco_cuda_extension = None
    _cuda_extension_available = False


59
def _normalize_convolution_tensor_s2(
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
):
    """
    Discretely normalizes the convolution tensor and optionally pre-applies quadrature weights.

    Supports the following three normalization modes (``basis_norm_mode``):
    - "none": No normalization is applied (values are only divided by ``1 + eps``).
    - "individual": for each output latitude and filter basis function the filter is numerically
      integrated over the sphere and normalized so that it yields 1.
    - "mean": the norm is computed for each output latitude and then averaged over the output
      latitudes. Each basis function is then normalized by this mean.

    The norm used is the quadrature-weighted 1-norm, sum(|psi| * q), accumulated per
    (basis function, output latitude) pair.

    Parameters
    ----------
    psi_idx : torch.Tensor
        COO indices with three rows: kernel index, output latitude and flattened input position
        ``ilat_in * in_shape[1] + ilon_in``.
    psi_vals : torch.Tensor
        Values belonging to ``psi_idx``. The tensor is not modified in place.
    in_shape, out_shape : tuple of int
        (nlat, nlon) of the input and output grids.
    kernel_size : int
        Number of filter basis functions.
    quad_weights : torch.Tensor
        Quadrature weights indexed by latitude (one weight per latitude, e.g. of shape (nlat, 1)).
    transpose_normalization : bool
        If True, norms are accumulated over the latitudes encoded in the flattened position
        (``psi_idx[2]``) instead of ``psi_idx[1]``. The values are additionally divided by
        ``out_shape[1] / in_shape[1]`` when ``merge_quadrature`` is set.
    basis_norm_mode : str
        One of "none", "individual" or "mean".
    merge_quadrature : bool
        If True, the quadrature weights are multiplied into the returned values.
    eps : float
        Added to every normalizer to avoid division by zero.

    Returns
    -------
    torch.Tensor
        The normalized values, one per entry of ``psi_vals``.

    Raises
    ------
    ValueError
        If ``basis_norm_mode`` is unknown.
    """

    # validate up front instead of after the (potentially expensive) norm computation
    if basis_norm_mode not in ("none", "individual", "mean"):
        raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

    # decompose the flattened input position: latitude = position // nlon_in
    ikernel = psi_idx[0]
    lat_from_row = psi_idx[1]
    lat_from_col = psi_idx[2] // in_shape[1]

    if transpose_normalization:
        # here we are deliberately swapping input and output roles to handle transpose normalization with the same code
        ilat_out, ilat_in = lat_from_col, lat_from_row
        nlat_out = in_shape[0]
        correction_factor = out_shape[1] / in_shape[1]
    else:
        ilat_out, ilat_in = lat_from_row, lat_from_col
        nlat_out = out_shape[0]

    # quadrature weight belonging to each non-zero entry
    q = quad_weights[ilat_in].reshape(-1)

    # quadrature-weighted 1-norm per (basis function, output latitude), accumulated with a single
    # scatter-add in the working precision of the values (no float32 truncation)
    weighted = psi_vals.abs() * q
    flat_group = ikernel * nlat_out + ilat_out
    vnorm = torch.zeros(kernel_size * nlat_out, dtype=weighted.dtype, device=weighted.device)
    vnorm.index_add_(0, flat_group, weighted)
    vnorm = vnorm.reshape(kernel_size, nlat_out)

    # gather the normalizer that applies to each entry
    if basis_norm_mode == "individual":
        normalizer = vnorm[ikernel, ilat_out]
    elif basis_norm_mode == "mean":
        normalizer = vnorm.mean(dim=1)[ikernel]
    else:
        normalizer = 1.0

    # out-of-place so the caller's tensor is left untouched
    psi_vals = psi_vals / (normalizer + eps)

    if merge_quadrature:
        psi_vals = psi_vals * q

    if transpose_normalization and merge_quadrature:
        psi_vals = psi_vals / correction_factor

    return psi_vals


def _precompute_convolution_tensor_s2(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    theta_eps=1e-3,
    transpose_normalization=False,
    basis_norm_mode="mean",
    merge_quadrature=False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape, out_shape : tuple of int
        (nlat, nlon) of the input and output grids.
    filter_basis
        Filter basis object providing ``kernel_size`` and ``compute_support_vals``.
    grid_in, grid_out : str
        Latitude grid types passed to ``_precompute_latitudes``.
    theta_cutoff : float
        Radius of the filter support.
    theta_eps : float
        Relative enlargement of the cutoff passed to the filter basis, to avoid aliasing with the grid width.
    transpose_normalization : bool
        If True, the quadrature weights of the output grid are used and the normalization is done
        in transposed fashion (see ``_normalize_convolution_tensor_s2``).
    basis_norm_mode : str
        Normalization mode forwarded to ``_normalize_convolution_tensor_s2``.
    merge_quadrature : bool
        If True, the quadrature weights are merged into the returned values.

    Returns
    -------
    out_idx : torch.Tensor
        COO indices with three rows (kernel index, output latitude, flattened input position).
    out_vals : torch.Tensor
        Corresponding float32 values.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, _ = out_shape

    # precompute input and output grids
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in)
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out)

    # compute the phi differences
    # It's important not to include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]

    # compute quadrature weights used for normalization (and merged into the tensor if requested).
    # These quadrature weights integrate to 1 over the sphere.
    if transpose_normalization:
        quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0

    # effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], torch.full_like(iidx[:, 0], t), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1)
    out_vals = torch.cat(out_vals, dim=-1)

    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    out_idx = out_idx.contiguous()
    out_vals = out_vals.to(dtype=torch.float32).contiguous()

    return out_idx, out_vals
Boris Bonev's avatar
Boris Bonev committed
243
244
245
246


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous convolutions.

    Holds the filter basis, a learnable weight of shape
    (out_channels, in_channels // groups, kernel_size) and an optional zero-initialized bias
    of shape (out_channels,). Subclasses implement ``forward``.

    Parameters
    ----------
    in_channels : int
        Number of input channels; must be divisible by ``groups``.
    out_channels : int
        Number of output channels; must be divisible by ``groups``.
    kernel_shape : int or list of int
        Shape of the filter basis, forwarded to ``get_filter_basis``.
    basis_type : str, optional
        Type of the filter basis, forwarded to ``get_filter_basis``.
    groups : int, optional
        Number of channel groups.
    bias : bool, optional
        Whether to create a learnable bias.

    Raises
    ------
    ValueError
        If ``groups`` is not positive or does not divide both channel counts.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        # validate the cheap group arguments before building the filter basis;
        # a non-positive group count would otherwise surface as ZeroDivisionError
        if groups < 1:
            raise ValueError(f"Error, the number of groups has to be a positive integer, got {groups}.")
        if in_channels % groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the number of groups")
        if out_channels % groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the number of groups")

        self.kernel_shape = kernel_shape

        # get the filter basis functions
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        # groups
        self.groups = groups
        self.groupsize = in_channels // self.groups

        # weight tensor, scaled by 1 / sqrt(fan-in) where fan-in = groupsize * kernel_size
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @property
    def kernel_size(self):
        """Number of filter basis functions, as reported by the filter basis."""
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].

    The convolution tensor psi is precomputed once at construction time and kept in COO form in
    non-persistent buffers (``psi_ker_idx``, ``psi_row_idx``, ``psi_col_idx``, ``psi_vals``, plus
    ``psi_roff_idx`` when the CUDA extension is available). ``forward`` first contracts the input
    with psi and then mixes channels per group with the learnable weights.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # the longitudinal shift of the contraction needs nlon_in to be a multiple of nlon_out
        assert self.nlon_in % self.nlon_out == 0

        # heuristic default: one output latitude spacing, pi / (nlat_out - 1)
        cutoff = torch.pi / float(self.nlat_out - 1) if theta_cutoff is None else theta_cutoff
        if cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        coo_idx, coo_vals = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the three COO index rows into separate contiguous vectors (no sorting happens here)
        ker_idx, row_idx, col_idx = (coo_idx[row, ...].contiguous() for row in range(3))
        coo_vals = coo_vals.contiguous()

        if _cuda_extension_available:
            # extra row-offset structure consumed by the custom CUDA kernel
            offsets = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, coo_vals).contiguous()
            self.register_buffer("psi_roff_idx", offsets, persistent=False)

        # keep the sparse tensor components as non-persistent buffers so they follow .to(device)
        for buffer_name, tensor in (
            ("psi_ker_idx", ker_idx),
            ("psi_row_idx", row_idx),
            ("psi_col_idx", col_idx),
            ("psi_vals", coo_vals),
        ):
            self.register_buffer(buffer_name, tensor, persistent=False)

    def extra_repr(self):
        r"""
        Summarize the grid shapes, channel counts, filter basis and grouping of this module.
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Stacked COO indices (kernel, output latitude, flattened input position)."""
        return torch.stack((self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx)).contiguous()

    def get_psi(self):
        """Assemble the coalesced sparse convolution tensor of shape (kernel_size, nlat_out, nlat_in * nlon_in)."""
        dense_shape = (self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
        return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=dense_shape).coalesce()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        use_cuda_kernel = x.is_cuda and _cuda_extension_available
        if not use_cuda_kernel:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.get_psi(), self.nlon_out)
        else:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )

        # the contraction yields one spatial map per channel and basis function
        batch, _, nkernel, nlat, nlon = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, nkernel, nlat, nlon)

        # per-group channel mixing over input channels and basis functions
        grouped_weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        out = torch.einsum("bgckxy,gock->bgoxy", x, grouped_weight).contiguous()
        out = out.reshape(batch, -1, nlat, nlon)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


Boris Bonev's avatar
Boris Bonev committed
398
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].

    The convolution tensor is precomputed with input and output shapes/grids swapped and with
    transpose normalization, then kept in COO form in non-persistent buffers. ``forward`` first
    mixes channels per group with the learnable weights and then applies the transpose contraction.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # the longitudinal shift of the contraction needs nlon_out to be a multiple of nlon_in
        assert self.nlon_out % self.nlon_in == 0

        # heuristic default bandlimit: one input latitude spacing, pi / (nlat_in - 1)
        cutoff = torch.pi / float(self.nlat_in - 1) if theta_cutoff is None else theta_cutoff
        if cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # shapes and grids are swapped on purpose: the transpose reuses the forward tensor layout
        coo_idx, coo_vals = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the three COO index rows into separate contiguous vectors (no sorting happens here)
        ker_idx, row_idx, col_idx = (coo_idx[row, ...].contiguous() for row in range(3))
        coo_vals = coo_vals.contiguous()

        if _cuda_extension_available:
            # extra row-offset structure consumed by the custom CUDA kernel; rows index input latitudes
            offsets = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, coo_vals).contiguous()
            self.register_buffer("psi_roff_idx", offsets, persistent=False)

        # keep the sparse tensor components as non-persistent buffers so they follow .to(device)
        for buffer_name, tensor in (
            ("psi_ker_idx", ker_idx),
            ("psi_row_idx", row_idx),
            ("psi_col_idx", col_idx),
            ("psi_vals", coo_vals),
        ):
            self.register_buffer(buffer_name, tensor, persistent=False)

    def extra_repr(self):
        r"""
        Summarize the grid shapes, channel counts, filter basis and grouping of this module.
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Stacked COO indices (kernel, input latitude, flattened output position)."""
        return torch.stack((self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx)).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Assemble the coalesced sparse convolution tensor.

        Without ``semi_transposed`` the stored layout (kernel_size, nlat_in, nlat_out * nlon_out) is used.
        With it, output latitude and input latitude swap roles and the longitude axis is flipped, giving
        shape (kernel_size, nlat_out, nlat_in * nlon_out) for the PyTorch transpose contraction.
        """
        if not semi_transposed:
            stored_shape = (self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
            return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=stored_shape).coalesce()

        # we do a semi-transposition to facilitate the computation
        stacked = self.psi_idx
        tout = stacked[2] // self.nlon_out
        pout = stacked[2] % self.nlon_out
        # flip the axis of longitudes
        pout = self.nlon_out - 1 - pout
        tin = stacked[1]
        semi_idx = torch.stack([stacked[0], tout, tin * self.nlon_out + pout], dim=0)
        semi_shape = (self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)
        return torch.sparse_coo_tensor(semi_idx, self.psi_vals, size=semi_shape).coalesce()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, _, nlat, nlon = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, nlat, nlon)

        # per-group channel mixing that also expands every channel over the basis functions
        grouped_weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        x = torch.einsum("bgcxy,gock->bgokxy", x, grouped_weight).contiguous()
        x = x.reshape(batch, -1, x.shape[-3], nlat, nlon)

        use_cuda_kernel = x.is_cuda and _cuda_extension_available
        if not use_cuda_kernel:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.get_psi(semi_transposed=True), self.nlon_out)
        else:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out