# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
import math
from functools import partial
from itertools import accumulate
from typing import List, Tuple, Union, Optional
from warnings import warn

import torch
import torch.nn as nn

from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes
from torch_harmonics._disco_convolution import _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
from torch_harmonics.filter_basis import get_filter_basis
from torch_harmonics.convolution import (
    _normalize_convolution_tensor_s2,
    DiscreteContinuousConv,
)
from torch_harmonics.distributed import polar_group_size, azimuth_group_size
from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim

# import custom C++/CUDA extensions if available
# If either import fails, all three names below are still bound so that later
# code can test `_cuda_extension_available` without risking a NameError.
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension

    _cuda_extension_available = True
except ImportError:
    preprocess_psi = None
    disco_cuda_extension = None
    _cuda_extension_available = False


def _precompute_distributed_convolution_tensor_s2(
    in_shape,
    out_shape,
    filter_basis,
    grid_in="equiangular",
    grid_out="equiangular",
    theta_cutoff=0.01 * math.pi,
    theta_eps=1e-3,
    transpose_normalization=False,
    basis_norm_mode="mean",
    merge_quadrature=False,
):
    r"""
    Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i \nu = Y(-\theta_j)Z(\phi_i - \phi_j)Y(\theta_j)\nu$.
    Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.

    The full convolution tensor has shape kernel_shape x nlat_out x (nlat_in * nlon_in). It is built and normalized
    globally; afterwards only the entries whose input latitude belongs to the calling rank of the polar group are
    kept, and their input index is re-based to that rank's local latitude range.

    The rotation of the Euler angles uses the YZY convention, which applied to the northpole $(0,0,1)^T$ yields
    $$
    Y(\alpha) Z(\beta) Y(\gamma) n =
        {\begin{bmatrix}
            \cos(\gamma)\sin(\alpha) + \cos(\alpha)\cos(\beta)\sin(\gamma) \\
            \sin(\beta)\sin(\gamma) \\
            \cos(\alpha)\cos(\gamma)-\cos(\beta)\sin(\alpha)\sin(\gamma)
        \end{bmatrix}}
    $$

    Parameters
    ----------
    in_shape, out_shape : sequence of length 2
        (nlat, nlon) of the input and output grids.
    filter_basis :
        Object exposing ``kernel_size`` and ``compute_support_vals(theta, phi, r_cutoff=...)``.
    grid_in, grid_out : str
        Grid names forwarded to ``_precompute_latitudes``.
    theta_cutoff : float
        Cutoff radius of the filter support; it is enlarged by the factor ``1 + theta_eps`` before use.
    theta_eps : float
        Relative enlargement of ``theta_cutoff``.
    transpose_normalization : bool
        If True the quadrature weights of the output grid are used, otherwise those of the input grid.
        The flag is also forwarded to ``_normalize_convolution_tensor_s2``.
    basis_norm_mode, merge_quadrature :
        Forwarded unchanged to ``_normalize_convolution_tensor_s2``.

    Returns
    -------
    out_idx : torch.Tensor
        Tensor of shape (3, nnz) holding, per retained entry, the kernel index, the output latitude index and the
        flattened local input index ``(lat_in - start_idx) * nlon_in + lon_in``.
    out_vals : torch.Tensor
        float32 tensor with the nnz matching values.
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    kernel_size = filter_basis.kernel_size

    nlat_in, nlon_in = in_shape
    nlat_out, nlon_out = out_shape

    # precompute input and output grids
    lats_in, win = _precompute_latitudes(nlat_in, grid=grid_in)
    lats_in = torch.from_numpy(lats_in)
    lats_out, wout = _precompute_latitudes(nlat_out, grid=grid_out)
    lats_out = torch.from_numpy(lats_out)

    # compute the phi differences
    # It's important to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1, dtype=torch.float64)[:-1]

    # compute quadrature weights and merge them into the convolution tensor.
    # These quadrature integrate to 1 over the sphere.
    if transpose_normalization:
        quad_weights = torch.from_numpy(wout).reshape(-1, 1) / nlon_in / 2.0
    else:
        quad_weights = torch.from_numpy(win).reshape(-1, 1) / nlon_in / 2.0

    # effective theta cutoff is multiplied with a fudge factor to avoid aliasing with grid width (especially near poles)
    theta_cutoff_eff = (1.0 + theta_eps) * theta_cutoff

    out_idx = []
    out_vals = []
    for t in range(nlat_out):
        # the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
        alpha = -lats_out[t]
        beta = lons_in
        gamma = lats_in.reshape(-1, 1)

        # compute cartesian coordinates of the rotated position
        # This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
        # and therefore applied with a negative sign
        x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
        y = torch.sin(beta) * torch.sin(gamma)
        z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)

        # normalization is important to avoid NaNs when arccos and atan are applied
        # this can otherwise lead to spurious artifacts in the solution
        norm = torch.sqrt(x * x + y * y + z * z)
        x = x / norm
        y = y / norm
        z = z / norm

        # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
        theta = torch.arccos(z)
        phi = torch.arctan2(y, x)
        phi = torch.where(phi < 0.0, phi + 2 * torch.pi, phi)

        # find the indices where the rotated position falls into the support of the kernel
        iidx, vals = filter_basis.compute_support_vals(theta, phi, r_cutoff=theta_cutoff_eff)

        # add the output latitude and reshape such that psi has dimensions kernel_shape x nlat_out x (nlat_in*nlon_in)
        idx = torch.stack([iidx[:, 0], t * torch.ones_like(iidx[:, 0]), iidx[:, 1] * nlon_in + iidx[:, 2]], dim=0)

        # append indices and values to the COO datastructure
        out_idx.append(idx)
        out_vals.append(vals)

    # concatenate the indices and values
    out_idx = torch.cat(out_idx, dim=-1)
    out_vals = torch.cat(out_vals, dim=-1)

    out_vals = _normalize_convolution_tensor_s2(
        out_idx,
        out_vals,
        in_shape,
        out_shape,
        kernel_size,
        quad_weights,
        transpose_normalization=transpose_normalization,
        basis_norm_mode=basis_norm_mode,
        merge_quadrature=merge_quadrature,
    )

    # TODO: this part can be split off into its own function
    # split the latitude indices:
    comm_size_polar = polar_group_size()
    comm_rank_polar = polar_group_rank()
    split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar)
    offsets = [0] + list(accumulate(split_shapes))
    start_idx = offsets[comm_rank_polar]
    end_idx = offsets[comm_rank_polar + 1]

    # once normalization is done we can throw away the entries which correspond to input latitudes we do not care about
    lats = out_idx[2] // nlon_in
    lons = out_idx[2] % nlon_in
    # A boolean mask always yields 1-D selections. The previous argwhere(...).squeeze() collapsed to a 0-d index
    # when exactly one entry matched, which dropped the nnz dimension of the returned tensors.
    local_mask = (lats >= start_idx) & (lats < end_idx)
    out_vals = out_vals[local_mask]
    # for the indices we need to recompute them to refer to local indices of the input tensor
    out_idx = torch.stack([out_idx[0, local_mask], out_idx[1, local_mask], (lats[local_mask] - start_idx) * nlon_in + lons[local_mask]], dim=0)

    out_idx = out_idx.contiguous()
    out_vals = out_vals.to(dtype=torch.float32).contiguous()

    return out_idx, out_vals


class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Distributed version of Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603

    We assume the data can be split in polar and azimuthal directions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int, int],
        out_shape: Tuple[int, int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Build the layer and precompute the rank-local part of the convolution tensor.

        Parameters
        ----------
        in_channels, out_channels, kernel_shape, basis_type, groups, bias :
            Forwarded to the ``DiscreteContinuousConv`` base class.
        in_shape, out_shape :
            Global (nlat, nlon) of the input and output grids.
        basis_norm_mode, grid_in, grid_out :
            Forwarded to ``_precompute_distributed_convolution_tensor_s2``.
        theta_cutoff :
            Filter cutoff radius; defaults to ``pi / (nlat_out - 1)`` when None.

        Raises
        ------
        ValueError
            If ``theta_cutoff`` is not positive.
        """
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # we need those shapes:
        self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
        self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
        self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
        self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

        # compute theta cutoff based on the bandlimit of the input field
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_out - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in nlon direction is a convolution,
        # we will keep local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number
        # of atomic reduction calls inside the actual kernel

        # set local shapes according to distributed mode:
        # the input latitudes are split across the polar group, the output latitudes are kept in full
        self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
        self.nlat_out_local = self.nlat_out

        idx, vals = _precompute_distributed_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO index into its kernel / row (output latitude) / column (local input point) components
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Stacked (3, nnz) COO index built from the kernel, row and column index buffers."""
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self):
        """Return the rank-local convolution tensor as a coalesced sparse COO tensor of size
        (kernel_size, nlat_out_local, nlat_in_local * nlon_in)."""
        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_in)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the distributed DISCO convolution.

        NOTE(review): x is presumably laid out as (batch, channels, local nlat_in, local nlon_in) - confirm against
        the callers; only ``x.shape[1]`` (channel count) is read explicitly here.
        """

        # store number of channels
        num_chans = x.shape[1]

        # h and w is split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)

        if x.is_cuda and _cuda_extension_available:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")

            psi = self.get_psi()

            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)

        # perform reduce scatter in polar region
        x = reduce_from_polar_region(x)
        x = scatter_to_polar_region(x, -2)

        # now we can transpose back the result, so that lon is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)

        # extract shape
        B, C, K, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, K, H, W)

        # do weight multiplication
        out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        out = out.reshape(out.shape[0], -1, H, W)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Distributed version of Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].

    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int, int],
        out_shape: Tuple[int, int],
        kernel_shape: Union[int, List[int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        """
        Build the layer and precompute the rank-local part of the (transposed) convolution tensor.

        Parameters
        ----------
        in_channels, out_channels, kernel_shape, basis_type, groups, bias :
            Forwarded to the ``DiscreteContinuousConv`` base class.
        in_shape, out_shape :
            Global (nlat, nlon) of the input and output grids.
        basis_norm_mode, grid_in, grid_out :
            Forwarded to ``_precompute_distributed_convolution_tensor_s2`` (with the roles of input and
            output swapped, see below).
        theta_cutoff :
            Filter cutoff radius; defaults to ``pi / (nlat_in - 1)`` when None.

        Raises
        ------
        ValueError
            If ``theta_cutoff`` is not positive.
        """
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # we need those shapes:
        self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
        self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
        self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
        self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

        # bandlimit
        if theta_cutoff is None:
            theta_cutoff = torch.pi / float(self.nlat_in - 1)

        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # Note that the psi matrix is of shape nlat_out x nlat_in * nlon_in. Since the contraction in nlon direction is a convolution,
        # we will keep local to all nodes and split the computation up along nlat. We further split the input dim because this reduces the number
        # of atomic reduction calls inside the actual kernel

        # set local shapes according to distributed mode:
        # the input latitudes are kept in full, the output latitudes are split across the polar group
        self.nlat_in_local = self.nlat_in
        self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]

        # switch in_shape and out_shape since we want transpose conv
        # distributed mode here is swapped because of the transpose
        idx, vals = _precompute_distributed_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )

        # split the COO index into its kernel / row / column components
        ker_idx = idx[0, ...].contiguous()
        row_idx = idx[1, ...].contiguous()
        col_idx = idx[2, ...].contiguous()
        vals = vals.contiguous()

        if _cuda_extension_available:
            # preprocessed data-structure for GPU kernel
            roff_idx = preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, vals).contiguous()
            self.register_buffer("psi_roff_idx", roff_idx, persistent=False)

        # save all datastructures
        self.register_buffer("psi_ker_idx", ker_idx, persistent=False)
        self.register_buffer("psi_row_idx", row_idx, persistent=False)
        self.register_buffer("psi_col_idx", col_idx, persistent=False)
        self.register_buffer("psi_vals", vals, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"

    @property
    def psi_idx(self):
        """Stacked (3, nnz) COO index built from the kernel, row and column index buffers."""
        return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()

    def get_psi(self, semi_transposed: bool = False):
        """
        Return the rank-local convolution tensor as a coalesced sparse COO tensor.

        With ``semi_transposed=True`` the latitude roles of the stored row and column indices are exchanged and the
        longitude axis is flipped; the result has size (kernel_size, nlat_out_local, nlat_in_local * nlon_out).
        Otherwise the stored indices are used directly with size (kernel_size, nlat_in_local, nlat_out_local * nlon_out).
        """
        if semi_transposed:
            # do partial transpose
            # we do a semi-transposition to facilitate the computation
            tout = self.psi_idx[2] // self.nlon_out
            pout = self.psi_idx[2] % self.nlon_out
            # flip the axis of longitudes
            pout = self.nlon_out - 1 - pout
            tin = self.psi_idx[1]
            idx = torch.stack([self.psi_idx[0], tout, tin * self.nlon_out + pout], dim=0)
            psi = torch.sparse_coo_tensor(idx, self.psi_vals, size=(self.kernel_size, self.nlat_out_local, self.nlat_in_local * self.nlon_out)).coalesce()
        else:
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in_local, self.nlat_out_local * self.nlon_out)).coalesce()
        return psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the distributed transpose DISCO convolution.

        ``x`` must be 4-dimensional; it is unpacked as (B, C, H, W).
        NOTE(review): H and W are presumably the rank-local nlat_in / nlon_in - confirm against the callers.
        """

        # extract shape
        B, C, H, W = x.shape
        x = x.reshape(B, self.groups, self.groupsize, H, W)

        # do weight multiplication
        x = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])).contiguous()
        x = x.reshape(B, -1, x.shape[-3], H, W)
        num_chans = x.shape[1]

        # transpose such that lon is local, channels are split
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)

        # gather input tensor and set up backward reduction hooks
        x = gather_from_polar_region(x, -2, self.lat_in_shapes)
        x = copy_to_polar_region(x)

        if x.is_cuda and _cuda_extension_available:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            psi = self.get_psi(semi_transposed=True)
            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)

        # now we can transpose back the result, so that lon is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            out = distributed_transpose_azimuth.apply(out, (-1, 1), chan_shapes)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out