# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
from typing import List, Tuple, Union, Optional
from warnings import warn

import math

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes

from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda

from torch_harmonics.filter_basis import get_filter_basis

# import custom C++/CUDA extensions if available
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError:
    # keep the module name bound so later references are well-defined even
    # when the extension is missing; callers gate on _cuda_extension_available
    disco_cuda_extension = None
    _cuda_extension_available = False


def _normalize_convolution_tensor_s2(
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="sum", merge_quadrature=False, eps=1e-9
):
    """
    Discretely normalizes the convolution tensor. Supports different normalization modes.

    The sparse convolution tensor is given in COO form: ``psi_idx`` stacks
    [kernel index, output latitude, flattened input position (lat_in * nlon_in + lon_in)]
    and ``psi_vals`` holds the corresponding values. ``psi_vals`` is modified in place
    and also returned.

    Parameters
    ----------
    psi_idx: torch.Tensor
        COO indices of the sparse convolution tensor, shape (3, nnz).
    psi_vals: torch.Tensor
        Values of the sparse convolution tensor; normalized in place.
    in_shape: Tuple[int]
        (nlat_in, nlon_in) of the input grid.
    out_shape: Tuple[int]
        (nlat_out, nlon_out) of the output grid.
    kernel_size: int
        Number of filter basis functions.
    quad_weights: torch.Tensor
        Per-latitude quadrature weights used for the discrete normalization.
    transpose_normalization: bool, optional
        If True, normalization sums run over the latitude axis encoded in idx[2]
        (used for the transpose convolution); otherwise over idx[1].
    basis_norm_mode: str, optional
        "individual" normalizes each kernel basis function separately;
        "sum" normalizes the sum over all basis functions. Anything else raises ValueError.
    merge_quadrature: bool, optional
        If True, the quadrature weights are multiplied into the returned values.
    eps: float, optional
        Regularizer avoiding division by zero.

    Returns
    -------
    psi_vals: torch.Tensor
        The normalized (in-place) values.
    """

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # reshape the indices implicitly to be ikernel, lat_out, lat_in, lon_in
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // nlon_in, psi_idx[2] % nlon_in], dim=0)

    if transpose_normalization:
        # pre-compute the quadrature weights
        q = quad_weights[idx[1]].reshape(-1)

        # loop through dimensions which require normalization
        for ik in range(kernel_size):
            for ilat in range(nlat_in):

                # get relevant entries depending on the normalization mode
                if basis_norm_mode == "individual":

                    iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat))
                    # normalize, while summing also over the input longitude dimension here as this is not available for the output
                    vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
                elif basis_norm_mode == "sum":
                    # this will perform repeated normalization in the kernel dimension but this shouldn't lead to issues
                    iidx = torch.argwhere(idx[2] == ilat)
                    # normalize, while summing also over the input longitude dimension here as this is not available for the output
                    vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
                else:
                    raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

                if merge_quadrature:
                    # the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] * nlon_in / nlon_out / (vnorm + eps)
                else:
                    psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
    else:
        # pre-compute the quadrature weights
        q = quad_weights[idx[2]].reshape(-1)

        # loop through dimensions which require normalization
        for ik in range(kernel_size):
            for ilat in range(nlat_out):

                # get relevant entries depending on the normalization mode
                if basis_norm_mode == "individual":
                    iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat))
                    # normalize
                    vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
                elif basis_norm_mode == "sum":
                    # this will perform repeated normalization in the kernel dimension but this shouldn't lead to issues
                    iidx = torch.argwhere(idx[1] == ilat)
                    # normalize
                    vnorm = torch.sum(psi_vals[iidx].abs() * q[iidx])
                else:
                    raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

                if merge_quadrature:
                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (vnorm + eps)
                else:
                    psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)

    return psi_vals


def _precompute_convolution_tensor_s2(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    transpose_normalization=False,
    basis_norm_mode="sum",
    merge_quadrature=False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in), returned in sparse COO format
    as a pair (indices, values).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape: Tuple[int]
        (nlat_in, nlon_in) of the input grid.
    out_shape: Tuple[int]
        (nlat_out, nlon_out) of the output grid.
    filter_basis: FilterBasis
        Filter basis object providing kernel_size and compute_support_vals.
    grid_in, grid_out: str, optional
        Latitude grid types passed to _precompute_latitudes.
    theta_cutoff: float, optional
        Filter support radius (in radians) passed to the basis.
    transpose_normalization, basis_norm_mode, merge_quadrature:
        Forwarded to _normalize_convolution_tensor_s2.

    Returns
    -------
    out_idx: torch.Tensor
        Long tensor of COO indices, shape (3, nnz).
    out_vals: torch.Tensor
        Float tensor of the corresponding (normalized) values.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # precompute input and output grids
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    # compute quadrature weights that will be merged into the Psi tensor
    if transpose_normalization:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
    else:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    return out_idx, out_vals


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous convolutions.

    Sets up the state shared by all DISCO convolution variants: the filter
    basis built from ``kernel_shape``/``basis_type``, channel grouping, the
    learnable weight tensor and an optional bias. Subclasses provide the
    actual ``forward`` implementation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # construct the filter basis functions for the requested shape/type
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        # channel groups: both channel counts must divide evenly into groups
        self.groups = groups
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups

        # weight tensor, initialized with scale ~ 1/sqrt(groupsize * kernel_size)
        init_scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(init_scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        # optional per-output-channel bias
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    @property
    def kernel_size(self):
        # the number of basis functions is dictated by the filter basis
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].

    The sparse convolution tensor (psi) is precomputed at construction time and
    registered as non-persistent buffers; the forward pass contracts it with the
    input either via the custom CUDA extension (if available) or a PyTorch
    fallback.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "sum",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        idx, vals = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO indices into separate kernel/row/column index buffers
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure (row offsets) for the GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures; non-persistent so they are rebuilt rather than serialized
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        # re-assemble the (3, nnz) COO index tensor from the individual buffers
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self):
        # sparse filter tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in)
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # expects x of shape (batch, channels, nlat_in, nlon_in); contraction with
        # psi yields a (batch, channels, kernel_size, nlat_out, nlon_out) tensor
        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            psi = self.get_psi()
            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication: contract the group/channel/kernel dimensions
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].

    The sparse convolution tensor is precomputed with input and output shapes
    swapped (and transpose normalization) so that the same contraction machinery
    implements the transpose (upsampling) operation.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "sum",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # bandlimit heuristic based on the input resolution
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want transpose conv
        idx, vals = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO indices into separate kernel/row/column index buffers
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure (row offsets) for the GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures; non-persistent so they are rebuilt rather than serialized
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        # re-assemble the (3, nnz) COO index tensor from the individual buffers
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Assemble the sparse filter tensor. With ``semi_transposed=True`` the
        column index is re-encoded (and the longitude axis flipped) so the
        tensor can be used directly by the PyTorch transpose contraction.
        """
        if semi_transposed:
            # we do a semi-transposition to facilitate the computation
            tout = self.psi_idx[2] // self.nlon_out
            pout = self.psi_idx[2] % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_idx[1]
            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # extract shape; expects x of shape (batch, channels, nlat_in, nlon_in)
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication: expand channels into the kernel dimension first
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            psi = self.get_psi(semi_transposed=True)
            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out