convolution.py 19.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
import abc
33
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
from warnings import warn
35
36
37
38
39
40
41
42

import math

import torch
import torch.nn as nn

from functools import partial

Boris Bonev's avatar
Boris Bonev committed
43
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
Boris Bonev's avatar
Boris Bonev committed
44
45
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
46
from torch_harmonics.filter_basis import get_filter_basis
47

48
# import custom C++/CUDA extensions if available
Boris Bonev's avatar
Boris Bonev committed
49
try:
50
    from disco_helpers import preprocess_psi
Boris Bonev's avatar
Boris Bonev committed
51
    import disco_cuda_extension
52

Boris Bonev's avatar
Boris Bonev committed
53
54
55
56
57
58
    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False


59
def _normalize_convolution_tensor_s2(
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
):
    """
    Discretely normalizes the convolution tensor and optionally pre-applies quadrature weights.

    Supports the following three normalization modes:
    - "none": No normalization is applied.
    - "individual": for each output latitude and filter basis function the filter is numerically integrated over the sphere and normalized so that it yields 1.
    - "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.

    Parameters
    ----------
    psi_idx : integer tensor whose rows 0, 1, 2 hold, per non-zero entry, the kernel
        (basis function) index, a latitude index and a flattened ``lat * in_shape[1] + lon`` index.
    psi_vals : floating-point tensor with one value per non-zero entry. It is NOT modified;
        a new tensor is returned.
    in_shape, out_shape : ``(nlat, nlon)`` pairs.
    kernel_size : number of basis functions.
    quad_weights : quadrature weights indexed by input latitude, one weight per latitude
        (e.g. shape ``(nlat, 1)``).
    transpose_normalization : if True the roles of rows 1 and 2 of ``psi_idx`` are swapped
        (row 2 provides the output latitude, row 1 the input latitude) and the number of
        output latitudes is taken from ``in_shape``.
    basis_norm_mode : one of "none", "individual", "mean".
    merge_quadrature : if True the quadrature weights are multiplied into the returned values.
    eps : small constant added to the norm before dividing.

    Returns
    -------
    Tensor with the same shape and dtype as ``psi_vals`` holding the normalized values.

    Raises
    ------
    ValueError
        If ``basis_norm_mode`` is not one of the supported modes.

    Notes
    -----
    Kernel indices are expected in ``[0, kernel_size)`` and output latitude indices in
    ``[0, nlat_out)``; indices outside the accumulation buffer raise an error.
    """

    # validate up front so that an unknown mode is reported even if there is nothing to normalize
    if basis_norm_mode not in ("none", "individual", "mean"):
        raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

    # interpret the indices implicitly as ikernel, out_shape[0], in_shape[0], in_shape[1]
    ikernel = psi_idx[0].to(torch.long)
    ilat_row = psi_idx[1].to(torch.long)
    ilat_col = (psi_idx[2] // in_shape[1]).to(torch.long)

    if transpose_normalization:
        # here we are deliberately swapping input and output shapes to handle transpose normalization with the same code
        ilat_out, ilat_in = ilat_col, ilat_row
        nlat_out = in_shape[0]
    else:
        ilat_out, ilat_in = ilat_row, ilat_col
        nlat_out = out_shape[0]

    # quadrature weight associated with every non-zero entry
    q = quad_weights[ilat_in].reshape(-1)

    # accumulate the quadrature-weighted 1-norm for every (basis function, output latitude) pair
    # in a single scatter-add instead of scanning all entries once per pair
    weighted = psi_vals.abs() * q
    vnorm = torch.zeros(kernel_size * nlat_out, dtype=weighted.dtype, device=weighted.device)
    vnorm.index_add_(0, ikernel * nlat_out + ilat_out, weighted)
    vnorm = vnorm.reshape(kernel_size, nlat_out)

    # look up the normalization constant of every entry
    if basis_norm_mode == "individual":
        scale = vnorm[ikernel, ilat_out]
    elif basis_norm_mode == "mean":
        scale = vnorm.mean(dim=1)[ikernel]
    else:
        scale = 1.0

    # out-of-place division: the caller's tensor is left untouched
    normalized = psi_vals / (scale + eps)

    if merge_quadrature:
        normalized = normalized * q

        if transpose_normalization:
            # account for the differing number of longitudes of the two grids
            normalized = normalized / (out_shape[1] / in_shape[1])

    return normalized.to(psi_vals.dtype)


def _precompute_convolution_tensor_s2(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    transpose_normalization=False,
    basis_norm_mode="mean",
    merge_quadrature=False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape, out_shape : ``(nlat, nlon)`` pairs of the input and output grids.
    filter_basis : object providing ``kernel_size`` and
        ``compute_support_vals(theta, phi, r_cutoff=...)``.
    grid_in, grid_out : grid names forwarded to ``_precompute_latitudes``.
    theta_cutoff : cutoff passed to the filter basis as ``r_cutoff``.
    transpose_normalization, basis_norm_mode, merge_quadrature : forwarded to
        ``_normalize_convolution_tensor_s2``; ``transpose_normalization`` additionally selects
        the output-grid quadrature weights instead of the input-grid ones.

    Returns
    -------
    out_idx : ``torch.long`` tensor with three rows (kernel index, output latitude,
        flattened ``lat_in * nlon_in + lon_in``), one column per non-zero entry.
    out_vals : ``torch.float32`` tensor with the normalized values of the non-zero entries.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    # only the number of output latitudes is needed here
    nlat_out, _ = out_shape

    # precompute input and output grids
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    # compute quadrature weights and merge them into the convolution tensor.
    # These quadrature integrate to 1 over the sphere.
    if transpose_normalization:
        quad_weights = torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = torch.from_numpy(win).float().reshape(-1, 1) / nlon_in / 2.0

    # beta and gamma do not depend on the output latitude, so their trigonometric
    # functions are evaluated once outside of the loop
    beta = lons_in
    gamma = lats_in.reshape(-1, 1)
    cos_beta, sin_beta = torch.cos(beta), torch.sin(beta)
    cos_gamma, sin_gamma = torch.cos(gamma), torch.sin(gamma)

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        cos_alpha, sin_alpha = torch.cos(alpha), torch.sin(alpha)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        x = cos_alpha * cos_beta * sin_gamma + cos_gamma * sin_alpha
        y = sin_beta * sin_gamma
        z = -cos_beta * sin_alpha * sin_gamma + cos_alpha * cos_gamma

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    return out_idx, out_vals
Boris Bonev's avatar
Boris Bonev committed
236
237
238
239


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Shared scaffolding for discrete-continuous convolution layers.

    The constructor builds the filter basis from ``kernel_shape`` / ``basis_type``, checks that
    both channel counts are divisible by ``groups`` and creates the learnable ``weight`` of shape
    ``(out_channels, in_channels // groups, kernel_size)`` plus an optional zero-initialized
    ``bias`` of shape ``(out_channels,)``. Subclasses have to implement ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # the filter basis determines the number of basis functions (see kernel_size)
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        self.groups = groups

        # both channel counts have to split evenly across the groups
        divisibility_checks = (
            (in_channels, "Error, the number of input channels has to be an integer multiple of the group size"),
            (out_channels, "Error, the number of output channels has to be an integer multiple of the group size"),
        )
        for num_channels, message in divisibility_checks:
            if num_channels % self.groups != 0:
                raise ValueError(message)

        self.groupsize = in_channels // self.groups

        # random normal initialization scaled by the fan-in per group and basis function
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        initial_weight = torch.randn(out_channels, self.groupsize, self.kernel_size)
        self.weight = nn.Parameter(scale * initial_weight)

        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    @property
    def kernel_size(self):
        """Number of basis functions reported by the filter basis."""
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        """Apply the convolution to ``x``; must be provided by subclasses."""
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    DISCO (discrete-continuous) convolution layer on the 2-sphere, following [1].

    At construction time the sparse convolution tensor is precomputed for the given input and
    output grids and stored in non-persistent buffers. If ``theta_cutoff`` is not given it
    defaults to ``pi / (nlat_out - 1)``. The number of input longitudes has to be an integer
    multiple of the number of output longitudes.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # the longitudinal shift only works if the longitudes are divisible
        assert self.nlon_in % self.nlon_out == 0

        # heuristic cutoff derived from the resolution of the output grid
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        psi_indices, psi_values = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO indices into kernel, row (output latitude) and column components
        ker_idx, row_idx, col_idx = (psi_indices[axis, ...].contiguous() for axis in range(3))
        psi_values = psi_values.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, psi_values).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # keep all pieces of the convolution tensor as non-persistent buffers
        buffers = (
            ("psi_ker_idx", ker_idx),
            ("psi_row_idx", row_idx),
            ("psi_col_idx", col_idx),
            ("psi_vals", psi_values),
        )
        for buffer_name, buffer_tensor in buffers:
            self.register_buffer(buffer_name, buffer_tensor, persistent=False)

    def extra_repr(self):
        """
        Summary of the layer configuration used when printing the module.
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """COO indices of the convolution tensor stacked as (kernel, row, column)."""
        components = [self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx]
        return torch.stack(components, dim=0).contiguous()

    def get_psi(self):
        """Assemble the coalesced sparse convolution tensor of size kernel_size x nlat_out x (nlat_in * nlon_in)."""
        dense_size = (self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
        return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=dense_size).coalesce()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Contract ``x`` with the precomputed convolution tensor, then mix channels with the
        learnable weights (group-wise) and add the bias if present.
        """
        use_cuda_kernel = x.is_cuda and _cuda_extension_available

        if not use_cuda_kernel:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.get_psi(), self.nlon_out)
        else:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )

        # the contraction result carries an extra axis for the basis functions
        batch, _, nbasis, nlat, nlon = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, nbasis, nlat, nlon)

        # group-wise contraction over input channels and basis functions
        grouped_weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        out = torch.einsum("bgckxy,gock->bgoxy", x, grouped_weight).contiguous()
        out = out.reshape(batch, -1, nlat, nlon)

        if self.bias is None:
            return out

        return out + self.bias.reshape(1, -1, 1, 1)


Boris Bonev's avatar
Boris Bonev committed
391
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    DISCO (discrete-continuous) transpose convolution layer on the 2-sphere, following [1].

    The convolution tensor is precomputed with input and output grids swapped and stored in
    non-persistent buffers. If ``theta_cutoff`` is not given it defaults to
    ``pi / (nlat_in - 1)``. The number of output longitudes has to be an integer multiple of
    the number of input longitudes.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # the longitudinal shift only works if the longitudes are divisible
        assert self.nlon_out % self.nlon_in == 0

        # default cutoff derived from the resolution of the input grid
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # in_shape and out_shape (and the grids) are swapped since we want the transpose convolution
        psi_indices, psi_values = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO indices into kernel, row and column components
        ker_idx, row_idx, col_idx = (psi_indices[axis, ...].contiguous() for axis in range(3))
        psi_values = psi_values.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, psi_values).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # keep all pieces of the convolution tensor as non-persistent buffers
        buffers = (
            ("psi_ker_idx", ker_idx),
            ("psi_row_idx", row_idx),
            ("psi_col_idx", col_idx),
            ("psi_vals", psi_values),
        )
        for buffer_name, buffer_tensor in buffers:
            self.register_buffer(buffer_name, buffer_tensor, persistent=False)

    def extra_repr(self):
        """
        Summary of the layer configuration used when printing the module.
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """COO indices of the convolution tensor stacked as (kernel, row, column)."""
        components = [self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx]
        return torch.stack(components, dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Assemble the coalesced sparse convolution tensor.

        With ``semi_transposed=False`` the stored layout kernel_size x nlat_in x (nlat_out * nlon_out)
        is returned. With ``semi_transposed=True`` the latitude axes are exchanged and the
        longitude index is flipped, giving kernel_size x nlat_out x (nlat_in * nlon_out).
        """
        if not semi_transposed:
            stored_size = (self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
            return torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=stored_size).coalesce()

        # we do a semi-transposition to facilitate the computation
        kernel_idx, lat_in_idx, flat_out_idx = self.psi_idx.unbind(0)
        lat_out_idx = flat_out_idx // self.nlon_out
        # flip the axis of longitudes
        lon_out_idx = self.nlon_out - 1 - (flat_out_idx % self.nlon_out)

        transposed_idx = torch.stack([kernel_idx, lat_out_idx, lat_in_idx * self.nlon_out + lon_out_idx], dim=0)
        transposed_size = (self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)
        return torch.sparse_coo_tensor(transposed_idx, self.psi_vals, size=transposed_size).coalesce()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Mix channels with the learnable weights (group-wise, producing one field per basis
        function), contract with the precomputed convolution tensor and add the bias if present.
        """
        batch, _, nlat, nlon = x.shape
        grouped_input = x.reshape(batch, self.groups, self.groupsize, nlat, nlon)

        # group-wise weight multiplication, keeping a separate axis for the basis functions
        grouped_weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        mixed = torch.einsum("bgcxy,gock->bgokxy", grouped_input, grouped_weight).contiguous()
        mixed = mixed.reshape(batch, -1, mixed.shape[-3], nlat, nlon)

        use_cuda_kernel = mixed.is_cuda and _cuda_extension_available

        if not use_cuda_kernel:
            if mixed.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(mixed, self.get_psi(semi_transposed=True), self.nlon_out)
        else:
            out = _disco_s2_transpose_contraction_cuda(
                mixed, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )

        if self.bias is None:
            return out

        return out + self.bias.reshape(1, -1, 1, 1)