convolution.py 19.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

Boris Bonev's avatar
Boris Bonev committed
32
import abc
33
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
from warnings import warn
35
36
37
38
39
40
41
42

import math

import torch
import torch.nn as nn

from functools import partial

Thorsten Kurth's avatar
Thorsten Kurth committed
43
44
from torch_harmonics.cache import lru_cache
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
45
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
Boris Bonev's avatar
Boris Bonev committed
46
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
Thorsten Kurth's avatar
Thorsten Kurth committed
47
from torch_harmonics.filter_basis import FilterBasis, get_filter_basis
48

49
# import custom C++/CUDA extensions if available
Boris Bonev's avatar
Boris Bonev committed
50
try:
51
    from disco_helpers import preprocess_psi
Boris Bonev's avatar
Boris Bonev committed
52
    import disco_cuda_extension
53

Boris Bonev's avatar
Boris Bonev committed
54
55
56
57
58
59
    _cuda_extension_available = True
except ImportError as err:
    disco_cuda_extension = None
    _cuda_extension_available = False


60
def _normalize_convolution_tensor_s2(
61
    psi_idx, psi_vals, in_shape, out_shape, kernel_size, quad_weights, transpose_normalization=False, basis_norm_mode="mean", merge_quadrature=False, eps=1e-9
62
):
Boris Bonev's avatar
Boris Bonev committed
63
    """
64
65
66
67
    Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
    - "none": No normalization is applied.
    - "individual": for each output latitude and filter basis function the filter is numerically integrated over the sphere and normalized so that it yields 1.
    - "mean": the norm is computed for each output latitude and then averaged over the output latitudes. Each basis function is then normalized by this mean.
Boris Bonev's avatar
Boris Bonev committed
68
69
    """

70
71
    # reshape the indices implicitly to be ikernel, out_shape[0], in_shape[0], in_shape[1]
    idx = torch.stack([psi_idx[0], psi_idx[1], psi_idx[2] // in_shape[1], psi_idx[2] % in_shape[1]], dim=0)
Boris Bonev's avatar
Boris Bonev committed
72

73
74
    # getting indices for adressing kernels, input and output latitudes
    ikernel = idx[0]
Boris Bonev's avatar
Boris Bonev committed
75
76

    if transpose_normalization:
77
78
79
80
81
        ilat_out = idx[2]
        ilat_in = idx[1]
        # here we are deliberately swapping input and output shapes to handle transpose normalization with the same code
        nlat_out = in_shape[0]
        correction_factor = out_shape[1] / in_shape[1]
Boris Bonev's avatar
Boris Bonev committed
82
    else:
83
84
85
86
87
88
89
90
91
        ilat_out = idx[1]
        ilat_in = idx[2]
        nlat_out = out_shape[0]

    # get the quadrature weights
    q = quad_weights[ilat_in].reshape(-1)

    # buffer to store intermediate values
    vnorm = torch.zeros(kernel_size, nlat_out)
92
    support = torch.zeros(kernel_size, nlat_out)
93
94
95
96
97
98
99
100

    # loop through dimensions to compute the norms
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            # find indices corresponding to the given output latitude and kernel basis function
            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

101
102
103
            # compute the 1-norm
            # vnorm[ik, ilat] = torch.sqrt(torch.sum(psi_vals[iidx].abs().pow(2) * q[iidx]))
            vnorm[ik, ilat] = torch.sum(psi_vals[iidx].abs() * q[iidx])
104

105
106
107
108
            # compute the support
            support[ik, ilat] = torch.sum(q[iidx])


109
110
111
112
113
114
115
116
117
118
    # loop over values and renormalize
    for ik in range(kernel_size):
        for ilat in range(nlat_out):

            iidx = torch.argwhere((ikernel == ik) & (ilat_out == ilat))

            if basis_norm_mode == "individual":
                val = vnorm[ik, ilat]
            elif basis_norm_mode == "mean":
                val = vnorm[ik, :].mean()
119
120
            elif basis_norm_mode == "support":
                val = support[ik, ilat]
121
122
123
124
125
            elif basis_norm_mode == "none":
                val = 1.0
            else:
                raise ValueError(f"Unknown basis normalization mode {basis_norm_mode}.")

126
127
            psi_vals[iidx] = psi_vals[iidx] / (val + eps)

128
            if merge_quadrature:
129
                psi_vals[iidx] = psi_vals[iidx] * q[iidx]
130
131
132
133


    if transpose_normalization and merge_quadrature:
        psi_vals = psi_vals / correction_factor
Boris Bonev's avatar
Boris Bonev committed
134
135
136
137

    return psi_vals


Thorsten Kurth's avatar
Thorsten Kurth committed
138
@lru_cache(typed=True, copy=True)
def _precompute_convolution_tensor_s2(
    in_shape: Tuple[int],
    out_shape: Tuple[int],
    filter_basis: FilterBasis,
    grid_in: Optional[str]="equiangular",
    grid_out: Optional[str]="equiangular",
    theta_cutoff: Optional[float]=0.01 * math.pi,
    theta_eps: Optional[float]=1e-3,
    transpose_normalization: Optional[bool]=False,
    basis_norm_mode: Optional[str]="mean",
    merge_quadrature: Optional[bool]=False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.

    The result is returned in COO format: an index tensor of shape
    3 x nnz (kernel index, output latitude, flattened input position) together
    with the corresponding values, i.e. an implicit tensor of shape
    kernel_shape x nlat_out x (nlat_in * nlon_in).

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape, out_shape : (nlat, nlon) of the input and output grids.
    filter_basis : filter basis object providing kernel_size and compute_support_vals.
    grid_in, grid_out : latitude grid types, forwarded to _precompute_latitudes.
    theta_cutoff : filter support radius (radians).
    theta_eps : relative fudge factor widening the cutoff against grid aliasing.
    transpose_normalization, basis_norm_mode, merge_quadrature : forwarded to
        _normalize_convolution_tensor_s2.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # precompute input and output grids (latitudes with quadrature weights)
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = _precompute_longitudes(nlon_in)

    # compute quadrature weights and merge them into the convolution tensor.
    # These quadrature integrate to 1 over the sphere.
    if transpose_normalization:
        quad_weights = wout.reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

    # effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1)
    out_vals = torch.cat(out_vals, dim=-1)

    # normalize the filters and (optionally) fold the quadrature weights in
    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    out_idx = out_idx.contiguous()
    out_vals = out_vals.to(dtype=torch.float32).contiguous()

    return out_idx, out_vals
Boris Bonev's avatar
Boris Bonev committed
250
251
252
253


class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
    """
    Abstract base class for discrete-continuous convolutions.

    Handles the parts shared by the forward and transpose S2 convolutions:
    filter-basis construction, grouped weight tensor and optional bias.
    Subclasses must implement ``forward``.

    Parameters
    ----------
    in_channels : int
        Number of input channels; must be divisible by ``groups``.
    out_channels : int
        Number of output channels; must be divisible by ``groups``.
    kernel_shape : int or tuple
        Shape of the filter basis, forwarded to ``get_filter_basis``.
    basis_type : str, optional
        Filter basis type, forwarded to ``get_filter_basis``.
    groups : int, optional
        Number of channel groups (as in grouped convolutions).
    bias : bool, optional
        If True, adds a learnable per-output-channel bias.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        groups: Optional[int] = 1,
        bias: Optional[bool] = True,
    ):
        super().__init__()

        self.kernel_shape = kernel_shape

        # get the filter basis functions
        self.filter_basis = get_filter_basis(kernel_shape=kernel_shape, basis_type=basis_type)

        # groups
        self.groups = groups

        # weight tensor: channels must split evenly into groups
        if in_channels % self.groups != 0:
            raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
        if out_channels % self.groups != 0:
            raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
        self.groupsize = in_channels // self.groups
        # fan-in style init scale over (channels per group) x (basis functions)
        scale = math.sqrt(1.0 / self.groupsize / self.kernel_size)
        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

    @property
    def kernel_size(self):
        """Number of filter basis functions, as reported by the filter basis."""
        return self.filter_basis.kernel_size

    @abc.abstractmethod
    def forward(self, x: torch.Tensor):
        raise NotImplementedError


class DiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) convolutions on the 2-Sphere as described in [1].

    Precomputes the sparse convolution tensor (COO format) at construction time
    and stores it in non-persistent buffers; when the custom CUDA extension is
    available, an additional preprocessed row-offset structure is stored for the
    GPU kernel.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_in % self.nlon_out == 0

        # heuristic to compute theta cutoff based on the bandlimit of the input field and overlaps of the basis functions
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # precompute the sparse convolution tensor (quadrature weights merged in)
        idx, vals = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO index tensor into its three components
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, out_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers (recomputed on load)
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi as COO matrix just in case for torch input
        self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        # reassemble the 3 x nnz COO index tensor from the registered buffers
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # assumes x has shape (batch, channels, nlat_in, nlon_in) — TODO confirm against contraction helpers
        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)

        # extract shape; after the contraction x carries an extra kernel dimension K
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # contract the kernel dimension with the grouped weight tensor
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(B, -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


Boris Bonev's avatar
Boris Bonev committed
403
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Discrete-continuous (DISCO) transpose convolutions on the 2-Sphere as described in [1].

    The convolution tensor is precomputed with input and output shapes (and
    grids) swapped relative to the forward convolution, and with transpose
    normalization enabled; the weight multiplication happens before the sparse
    contraction.

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # make sure the p-shift works by checking that longitudes are divisible
        assert self.nlon_out % self.nlon_in == 0

        # bandlimit heuristic (based on the input resolution for the transpose conv)
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # switch in_shape and out_shape since we want the transpose convolution
        idx, vals = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO index tensor into its three components
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, in_shape[0], ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures as non-persistent buffers (recomputed on load)
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

        # also store psi (semi-transposed form) just in case for torch input
        self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        # reassemble the 3 x nnz COO index tensor from the registered buffers
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # extract shape; x is (batch, channels, nlat_in, nlon_in)
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # expand the channels into the kernel dimension via the grouped weights
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out