# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
import math
from functools import partial
from typing import List, Optional, Tuple, Union
from warnings import warn

import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis

# import custom C++/CUDA extensions if available
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError:
    # fall back to the pure PyTorch code paths below; define the names so that
    # accidental references fail with a clear error instead of a NameError
    preprocess_psi = None
    disco_cuda_extension = None
    _cuda_extension_available = False
def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, merge_quadrature=False, eps=1e-9):
Boris Bonev's avatar
Boris Bonev committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
    """
    Discretely normalizes the convolution tensor.
    """

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # reshape the indices implicitly to be ikernel, lat_out, lat_in, lon_in
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // nlon_in, psi_idx[2] % nlon_in], dim=0)

    if transpose_normalization:
        # pre-compute the quadrature weights
        q = quad_weights[idx[1]].reshape(-1)

        # loop through dimensions which require normalization
        for ik in range(kernel_size):
            for ilat in range(nlat_in):
                # get relevant entries
                iidx = torch.argwhere((idx[0] == ik) & (idx[2] == ilat))
                # normalize, while summing also over the input longitude dimension here as this is not available for the output
                vnorm = torch.sum(psi_vals[iidx] * q[iidx])
81
82
83
84
85
                if merge_quadrature:
                    # the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] * nlon_in / nlon_out / (vnorm + eps)
                else:
                    psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
Boris Bonev's avatar
Boris Bonev committed
86
87
88
89
90
91
92
93
94
95
96
    else:
        # pre-compute the quadrature weights
        q = quad_weights[idx[2]].reshape(-1)

        # loop through dimensions which require normalization
        for ik in range(kernel_size):
            for ilat in range(nlat_out):
                # get relevant entries
                iidx = torch.argwhere((idx[0] == ik) & (idx[1] == ilat))
                # normalize
                vnorm = torch.sum(psi_vals[iidx] * q[iidx])
97
98
99
100
                if merge_quadrature:
                    psi_vals[iidx] = psi_vals[iidx] * q[iidx] / (vnorm + eps)
                else:
                    psi_vals[iidx] = psi_vals[iidx] / (vnorm + eps)
Boris Bonev's avatar
Boris Bonev committed
101
102
103
104
105

    return psi_vals


def _precompute_convolution_tensor_s2(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    transpose_normalization=False,
    merge_quadrature=False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
    The output tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    NOTE: the docstring is a raw string on purpose; otherwise ``\theta`` and ``\nu``
    would be mangled into tab/newline escape sequences.

    Parameters
    ----------
    in_shape / out_shape : (nlat, nlon) of the input and output grids.
    filter_basis : filter basis object providing ``kernel_size`` and
        ``compute_support_vals`` (project-defined; see torch_harmonics.filter_basis).
    grid_in / grid_out : latitude grid types passed to the quadrature helper.
    theta_cutoff : support radius of the filter on the sphere.
    transpose_normalization / merge_quadrature : forwarded to
        ``_normalize_convolution_tensor_s2``.

    Returns
    -------
    (out_idx, out_vals) : COO indices (3 x nnz, long) and values (nnz, float32)
        of the sparse convolution tensor.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # latitudes and quadrature weights for both grids (numpy arrays from the quadrature helper)
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in).float()
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out).float()

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x) + torch.pi

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1).to(torch.long).contiguous()
    out_vals = torch.cat(out_vals, dim=-1).to(torch.float32).contiguous()

    # quadrature weights scaled by the longitudinal spacing; for the transpose case
    # the output-grid weights are used, otherwise the input-grid weights
    if transpose_normalization:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(wout).float().reshape(-1, 1) / nlon_in
    else:
        quad_weights = 2.0 * torch.pi * torch.from_numpy(win).float().reshape(-1, 1) / nlon_in
    out_vals = _normalize_convolution_tensor_s2(
        out_idx, out_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=transpose_normalization, merge_quadrature=merge_quadrature
    )

    return out_idx, out_vals
class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for DISCO convolutions.

    Holds the filter basis, the grouped weight tensor and the optional bias;
    subclasses implement the actual spherical contraction in ``forward``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # basis functions spanning the filter support
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type="piecewise linear")

        # number of channel groups
        self.groups = groups

        # channels must split evenly across the groups
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups

        # scale the random initialization by fan-in so activations stay well-conditioned
        init_scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(init_scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        # optional additive bias, one entry per output channel
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

    @property
    def kernel_size(self):
        # the number of basis functions determines the kernel size
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    The sparse convolution tensor psi is precomputed at construction time and
    stored as non-persistent buffers (COO indices plus values); forward then
    contracts the input with psi and applies the learned per-basis weights.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # precompute the sparse convolution tensor in COO form (3 x nnz indices, nnz values)
        idx, vals = _precompute_convolution_tensor_s2(
            in_shape, out_shape, self.filter_basis, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff, transpose_normalization=False, merge_quadrature=True
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers so they follow the module's device/dtype
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        # reassemble the 3 x nnz COO index tensor from the individual buffers
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self):
        # sparse tensor of shape kernel_size x nlat_out x (nlat_in * nlon_in)
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): x is presumably (batch, channels, nlat_in, nlon_in) — the
        # contraction helpers return (batch, channels, kernel_size, nlat_out, nlon_out)

        # dispatch to the custom CUDA kernel when possible, else the PyTorch fallback
        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            psi = self.get_psi()
            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication: contract kernel basis and group-input channels
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    The convolution tensor is precomputed with input and output shapes swapped
    (and transpose normalization), so that this module acts as the adjoint of
    :class:`DiscreteContinuousConvS2`.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, List[int]],
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # bandlimit heuristic; here based on the input resolution since the output is finer
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want transpose conv
        idx, vals = _precompute_convolution_tensor_s2(
            out_shape, in_shape, self.filter_basis, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff, transpose_normalization=True, merge_quadrature=True
        )

        # sort the values
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers so they follow the module's device/dtype
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        # reassemble the 3 x nnz COO index tensor from the individual buffers
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        if semi_transposed:
            # we do a semi-transposition to facilitate the computation:
            # split the flattened output position into latitude and longitude
            tout = self.psi_idx[2] // self.nlon_out
            pout = self.psi_idx[2] % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_idx[1]
            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()

        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication first (weights are applied before the transpose contraction)
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        # dispatch to the custom CUDA kernel when possible, else the PyTorch fallback
        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            psi = self.get_psi(semi_transposed=True)
            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out