# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
from typing import List, Tuple, Union, Optional
from warnings import warn

import math

import torch
import torch.nn as nn

from functools import partial

from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis

# import custom C++/CUDA extensions if available
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False


def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
    """
    Discretely normalizes the convolution tensor so that the discrete filters integrate
    to one on the sphere with respect to the given quadrature rule.
    """

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # reshape the indices implicitly to be ikernel, lat_out, lat_in, lon_in
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // nlon_in, psi_idx[2] % nlon_in], dim=0)

    if transpose_normalization:
        # pre-compute the quadrature weights
        q = quad_weights[idx[1]].reshape(-1)

        # loop through dimensions which require normalization
        for ik in range(kernel_size):
            for ilat in range(nlat_in):
                # get relevant entries
                iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat))
                # normalize, while summing also over the input longitude dimension here as this is not available for the output
                vnorm = torch.sum(psi_vals[iidx] * q[iidx])
                if merge_quadrature:
                    # the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] * nlon_in / nlon_out / (vnorm + eps)
                else:
                    psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
    else:
        # pre-compute the quadrature weights
        q = quad_weights[idx[2]].reshape(-1)

        # loop through dimensions which require normalization
        for ik in range(kernel_size):
            for ilat in range(nlat_out):
                # get relevant entries
                iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat))
                # normalize
                vnorm = torch.sum(psi_vals[iidx] * q[iidx])
                if merge_quadrature:
                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (vnorm + eps)
                else:
                    psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)

    return psi_vals


def _precompute_convolution_tensor_s2(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    transpose_normalization=False,
    merge_quadrature=False,
):
    """
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the north pole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma) - \cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences
    # it is important not to include the 2*pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)
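        # beta sweeps the longitude differences (the output longitude can be fixed to zero
        # thanks to the equidistant sampling in longitude), gamma the input latitudes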

        # compute cartesian coordinates of the rotated position
        # this uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation
        # and is therefore applied with a negative sign
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied,
        # which can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO data structure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

    if transpose_normalization:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
    else:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
    out_vals = _normalize_convolution_tensor_s2(
        out_idx, out_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
    )

    return out_idx, out_vals


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for DISCO convolutions
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # get the filter basis functions
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        # groups
        self.groups = groups

        # weight tensor
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the number of groups")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the number of groups")
        self.groupsize = in_channels // self.groups
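        # scale the initial weights so that their variance is 1 / (groupsize * kernel_size),
        # i.e. fan-in scaling over the channel and kernel dimensions contracted in forward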
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @property
    def kernel_size(self):
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
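
    A minimal usage sketch; the grid sizes and channel counts below are illustrative
    assumptions rather than prescribed values:

    >>> conv = DiscreteContinuousConvS2(in_channels=4, out_channels=8, in_shape=(16, 32), out_shape=(8, 16), kernel_shape=3)
    >>> x = torch.randn(2, 4, 16, 32)
    >>> conv(x).shape
    torch.Size([2, 8, 8, 16])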
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)
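            # e.g. an equiangular output grid with nlat_out = 181 (one-degree spacing)
            # yields theta_cutoff = pi / 180, i.e. one latitude grid spacing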

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        idx, vals = _precompute_convolution_tensor_s2(
            in_shape, out_shape, self.filter_basis, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
        )

        # split the precomputed indices into kernel, row and column index arrays
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all data structures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self):
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            psi = self.get_psi()
            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication
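        # x has shape (B, groups, groupsize, K, H, W); the weight is reshaped to
        # (groups, out_channels_per_group, groupsize, K) and contracted over c and k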
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
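
    A minimal usage sketch; the grid sizes and channel counts below are illustrative
    assumptions rather than prescribed values:

    >>> convt = DiscreteContinuousConvTransposeS2(in_channels=8, out_channels=4, in_shape=(8, 16), out_shape=(16, 32), kernel_shape=3)
    >>> y = torch.randn(2, 8, 8, 16)
    >>> convt(y).shape
    torch.Size([2, 4, 16, 32])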
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want transpose conv
        idx, vals = _precompute_convolution_tensor_s2(
            out_shape, in_shape, self.filter_basis, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
        )

        # split the precomputed indices into kernel, row and column index arrays
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all data structures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        if semi_transposed:
            # we do a semi-transposition to facilitate the computation
            tout = self.psi_idx[2] // self.nlon_out
            pout = self.psi_idx[2] % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_idx[1]
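            # the semi-transposed tensor is indexed as (kernel, lat_out, lat_in * nlon_out)
            # rather than (kernel, lat_in, lat_out * nlon_out)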
            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication
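        # x has shape (B, groups, groupsize, H, W); the weight is reshaped to
        # (groups, out_channels_per_group, groupsize, K); unlike the forward convolution,
        # the kernel dimension k is produced here rather than contracted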
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            psi = self.get_psi(semi_transposed=True)
            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out
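

if __name__ == "__main__":
    # minimal smoke test, a sketch rather than part of the library API: the grid sizes,
    # channel counts and kernel_shape below are arbitrary assumptions, chosen so that a
    # DISCO convolution and its transpose counterpart round-trip the tensor shapes
    torch.manual_seed(0)

    conv = DiscreteContinuousConvS2(in_channels=4, out_channels=8, in_shape=(16, 32), out_shape=(8, 16), kernel_shape=3)
    convt = DiscreteContinuousConvTransposeS2(in_channels=8, out_channels=4, in_shape=(8, 16), out_shape=(16, 32), kernel_shape=3)

    x = torch.randn(2, 4, 16, 32)
    y = conv(x)
    z = convt(y)
    assert y.shape == (2, 8, 8, 16) and z.shape == (2, 4, 16, 32)
    print(f"shapes round-trip: {tuple(x.shape)} -> {tuple(y.shape)} -> {tuple(z.shape)}")

    # since the convolution tensor was normalized with merge_quadrature=True, the rows of
    # psi should sum to approximately one for every kernel basis function and latitude
    row_sums = conv.get_psi().to_dense().sum(dim=-1)
    print(f"psi row sums: min={row_sums.min().item():.4f}, max={row_sums.max().item():.4f}")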