distributed_convolution.py 20 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
from typing import List, Tuple, Union, Optional
Boris Bonev's avatar
Boris Bonev committed
34
35
from itertools import accumulate
from warnings import warn
36
37
38
39
40
41
42
43

import math

import torch
import torch.nn as nn

from functools import partial

Thorsten Kurth's avatar
Thorsten Kurth committed
44
from torch_harmonics.quadrature import _precompute_grid, _precompute_latitudes, _precompute_longitudes
45
from torch_harmonics._disco_convolution import _get_psi, _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch
Boris Bonev's avatar
Boris Bonev committed
46
from torch_harmonics._disco_convolution import _disco_s2_contraction_cuda, _disco_s2_transpose_contraction_cuda
47
from torch_harmonics.filter_basis import get_filter_basis
48
from torch_harmonics.convolution import (
Thorsten Kurth's avatar
Thorsten Kurth committed
49
    _precompute_convolution_tensor_s2,
50
51
52
    DiscreteContinuousConv,
)

53

54
55
from torch_harmonics.distributed import polar_group_size, azimuth_group_size
from torch_harmonics.distributed import distributed_transpose_azimuth, distributed_transpose_polar
56
from torch_harmonics.distributed import reduce_from_polar_region, scatter_to_polar_region, gather_from_polar_region, copy_to_polar_region
57
58
59
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim

60
# Optional custom C++/CUDA extensions. Only when both modules import cleanly are the CUDA
# DISCO kernels (and the preprocess_psi helper that prepares their input) used below;
# otherwise the layers fall back to the PyTorch contraction routines.
try:
    from disco_helpers import preprocess_psi
    import disco_cuda_extension
except ImportError:
    disco_cuda_extension = None
    _cuda_extension_available = False
else:
    _cuda_extension_available = True


Thorsten Kurth's avatar
Thorsten Kurth committed
71
72
73
74
75
def _split_distributed_convolution_tensor_s2(
    idx: torch.Tensor,
    vals: torch.Tensor,
    in_shape: Tuple[int],
    out_shape: Tuple[int],
):
    """
    Restrict a precomputed global DISCO convolution tensor to the input latitudes owned by
    the local rank of the polar process group.

    The convolution tensor is given in coordinate form: row 0 of ``idx`` holds the kernel
    index, row 1 the output-latitude index and row 2 the flattened input position
    ``lat * nlon_in + lon``; ``vals`` holds the value of each column of ``idx``.
    Entries whose input latitude lies outside this rank's latitude slab are dropped. The
    flattened input positions of the remaining entries are re-based so that they index
    into the local (latitude-split) input tensor.

    Parameters
    ----------
    idx: torch.Tensor
        Integer index tensor whose rows 0, 1, 2 are the kernel, output-latitude and
        flattened input indices of each nonzero entry
    vals: torch.Tensor
        Values of the nonzero entries, one per column of ``idx``
    in_shape: Tuple[int]
        Global (nlat_in, nlon_in) shape of the input field
    out_shape: Tuple[int]
        Global (nlat_out, nlon_out) shape of the output field; only its length is checked

    Returns
    -------
    idx: torch.Tensor
        Contiguous index tensor (same row layout) restricted to the local input latitudes
    vals: torch.Tensor
        Contiguous float32 values matching the returned ``idx`` columns
    """

    assert len(in_shape) == 2
    assert len(out_shape) == 2

    nlat_in, nlon_in = in_shape

    # latitude range [start_idx, end_idx) of the input owned by this polar rank
    comm_size_polar = polar_group_size()
    comm_rank_polar = polar_group_rank()
    split_shapes = compute_split_shapes(nlat_in, num_chunks=comm_size_polar)
    offsets = [0] + list(accumulate(split_shapes))
    start_idx = offsets[comm_rank_polar]
    end_idx = offsets[comm_rank_polar + 1]

    # decompose the flattened input position into latitude and longitude
    lats = idx[2] // nlon_in
    lons = idx[2] % nlon_in

    # Boolean mask rather than argwhere(...).squeeze(): squeeze collapses a single match
    # to a 0-d index, which would drop the entry dimension of the stacked result.
    # Masking keeps the original entry order.
    local = (lats >= start_idx) & (lats < end_idx)

    vals = vals[local]
    # re-base the input position so it refers to the local input tensor
    idx = torch.stack([idx[0, local], idx[1, local], (lats[local] - start_idx) * nlon_in + lons[local]], dim=0)

    # make results contiguous
    idx = idx.contiguous()
    vals = vals.to(dtype=torch.float32).contiguous()

    return idx, vals
151

Boris Bonev's avatar
Boris Bonev committed
152

153
154
155
class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
    """
    Distributed discrete-continuous (DISCO) convolution on the 2-sphere, following [1].

    The data is expected to be decomposed across two process groups: latitudes are split
    across the polar group and longitudes across the azimuth group.

    Parameters
    ----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Global (nlat, nlon) shape of the input field
    out_shape: Tuple[int]
        Global (nlat, nlon) shape of the output field
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the filter kernel
    basis_type: Optional[str]
        Type of filter basis, "piecewise linear" by default
    basis_norm_mode: Optional[str]
        Normalization mode for the filter basis, "mean" by default
    groups: Optional[int]
        Number of channel groups
    grid_in: Optional[str]
        Grid type of the input field
    grid_out: Optional[str]
        Grid type of the output field
    bias: Optional[bool]
        Whether to add a bias term
    theta_cutoff: Optional[float]
        Support radius of the filter; defaults to pi / (nlat_out - 1) and must be positive

    Raises
    ------
    ValueError
        If ``theta_cutoff`` is not positive.

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        # global spatial extents
        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # layout of the process grid
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # per-rank extents of every spatial dimension
        self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
        self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
        self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
        self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

        # default filter support is derived from the output latitude count
        theta_cutoff = torch.pi / float(self.nlat_out - 1) if theta_cutoff is None else theta_cutoff
        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # psi has shape nlat_out x (nlat_in * nlon_in). The longitudinal contraction is a
        # convolution and is kept local on every rank; the work is split along the input
        # latitudes instead, which also cuts down the atomic reductions in the kernel.
        self.nlat_in_local = self.lat_in_shapes[self.comm_rank_polar]
        self.nlat_out_local = self.nlat_out

        # global convolution tensor, then only the entries of the local input latitudes
        conv_idx, conv_vals, _ = _precompute_convolution_tensor_s2(
            in_shape,
            out_shape,
            self.filter_basis,
            grid_in=grid_in,
            grid_out=grid_out,
            theta_cutoff=theta_cutoff,
            transpose_normalization=False,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )
        conv_idx, conv_vals = _split_distributed_convolution_tensor_s2(conv_idx, conv_vals, in_shape, out_shape)

        # separate contiguous kernel / row / column index vectors
        ker_idx, row_idx, col_idx = (conv_idx[i, ...].contiguous() for i in range(3))
        conv_vals = conv_vals.contiguous()

        if _cuda_extension_available:
            # extra index structure consumed by the CUDA kernel (presumably row offsets)
            self.register_buffer(
                "psi_roff_idx",
                preprocess_psi(self.kernel_size, self.nlat_out_local, ker_idx, row_idx, col_idx, conv_vals).contiguous(),
                persistent=False,
            )

        # non-persistent buffers follow the module across devices but are not saved
        for buffer_name, buffer in (
            ("psi_ker_idx", ker_idx),
            ("psi_row_idx", row_idx),
            ("psi_col_idx", col_idx),
            ("psi_vals", conv_vals),
        ):
            self.register_buffer(buffer_name, buffer, persistent=False)

        # psi as consumed by the PyTorch fallback path in forward
        self.psi = _get_psi(
            self.kernel_size,
            self.psi_idx,
            self.psi_vals,
            self.nlat_in,
            self.nlon_in,
            self.nlat_out,
            self.nlon_out,
            self.nlat_in_local,
            self.nlat_out_local,
        )

    def extra_repr(self):
        r"""
        Summarize the layer configuration for the module's printed representation.
        """
        fields = [
            f"in_shape={(self.nlat_in, self.nlon_in)}",
            f"out_shape={(self.nlat_out, self.nlon_out)}",
            f"in_chans={self.groupsize * self.groups}",
            f"out_chans={self.weight.shape[0]}",
            f"filter_basis={self.filter_basis}",
            f"kernel_shape={self.kernel_shape}",
            f"groups={self.groups}",
        ]
        return ", ".join(fields)

    @property
    def psi_idx(self):
        # re-assemble the index tensor with rows (kernel, row, column)
        return torch.stack((self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx), dim=0).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the distributed convolution.

        ``x`` is presumably laid out as (batch, channels, local nlat_in, local nlon_in) —
        TODO confirm against callers. The result is 4-D with the bias (if any) added.
        """
        # channel count before any redistribution
        num_chans = x.shape[1]

        # lat and lon are both split; move the lon split into the channel dim so lon is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)

        use_cuda_kernel = x.is_cuda and _cuda_extension_available
        if use_cuda_kernel:
            x = _disco_s2_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            x = _disco_s2_contraction_torch(x, self.psi.to(x.device), self.nlon_out)

        # combine the per-rank contributions across the polar group, then split along latitude
        x = scatter_to_polar_region(reduce_from_polar_region(x), -2)

        # transpose back so that lon is split again and channels are local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-1, 1), compute_split_shapes(num_chans, self.comm_size_azimuth))

        batch, _, kernel, height, width = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, kernel, height, width)

        # contract channels and kernel taps with the grouped weights
        weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        out = torch.einsum("bgckxy,gock->bgoxy", x, weight).contiguous()
        out = out.reshape(out.shape[0], -1, height, width)

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out


class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
    """
    Distributed discrete-continuous (DISCO) transpose convolution on the 2-sphere, following [1].

    The data is expected to be decomposed across two process groups: latitudes are split
    across the polar group and longitudes across the azimuth group.

    Parameters
    ----------
    in_channels: int
        Number of input channels
    out_channels: int
        Number of output channels
    in_shape: Tuple[int]
        Global (nlat, nlon) shape of the input field
    out_shape: Tuple[int]
        Global (nlat, nlon) shape of the output field
    kernel_shape: Union[int, Tuple[int], Tuple[int, int]]
        Shape of the filter kernel
    basis_type: Optional[str]
        Type of filter basis, "piecewise linear" by default
    basis_norm_mode: Optional[str]
        Normalization mode for the filter basis, "mean" by default
    groups: Optional[int]
        Number of channel groups
    grid_in: Optional[str]
        Grid type of the input field
    grid_out: Optional[str]
        Grid type of the output field
    bias: Optional[bool]
        Whether to add a bias term
    theta_cutoff: Optional[float]
        Support radius of the filter; defaults to pi / (nlat_in - 1) and must be positive

    Raises
    ------
    ValueError
        If ``theta_cutoff`` is not positive.

    References
    ----------
    [1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_shape: Tuple[int],
        out_shape: Tuple[int],
        kernel_shape: Union[int, Tuple[int], Tuple[int, int]],
        basis_type: Optional[str] = "piecewise linear",
        basis_norm_mode: Optional[str] = "mean",
        groups: Optional[int] = 1,
        grid_in: Optional[str] = "equiangular",
        grid_out: Optional[str] = "equiangular",
        bias: Optional[bool] = True,
        theta_cutoff: Optional[float] = None,
    ):
        super().__init__(in_channels, out_channels, kernel_shape, basis_type, groups, bias)

        # global spatial extents
        self.nlat_in, self.nlon_in = in_shape
        self.nlat_out, self.nlon_out = out_shape

        # layout of the process grid
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # per-rank extents of every spatial dimension
        self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
        self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
        self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
        self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)

        # default filter support is derived from the input latitude count
        theta_cutoff = torch.pi / float(self.nlat_in - 1) if theta_cutoff is None else theta_cutoff
        if theta_cutoff <= 0.0:
            raise ValueError("Error, theta_cutoff has to be positive.")

        # Because of the transpose, the roles of input and output swap: the full input
        # latitude range is kept on every rank and the output latitudes are split.
        self.nlat_in_local = self.nlat_in
        self.nlat_out_local = self.lat_out_shapes[self.comm_rank_polar]

        # global convolution tensor with in/out shapes and grids swapped for the transpose
        conv_idx, conv_vals, _ = _precompute_convolution_tensor_s2(
            out_shape,
            in_shape,
            self.filter_basis,
            grid_in=grid_out,
            grid_out=grid_in,
            theta_cutoff=theta_cutoff,
            transpose_normalization=True,
            basis_norm_mode=basis_norm_mode,
            merge_quadrature=True,
        )
        # split along latitude, again with the meaning of in_shape and out_shape swapped
        conv_idx, conv_vals = _split_distributed_convolution_tensor_s2(conv_idx, conv_vals, out_shape, in_shape)

        # separate contiguous kernel / row / column index vectors
        ker_idx, row_idx, col_idx = (conv_idx[i, ...].contiguous() for i in range(3))
        conv_vals = conv_vals.contiguous()

        if _cuda_extension_available:
            # extra index structure consumed by the CUDA kernel (presumably row offsets)
            self.register_buffer(
                "psi_roff_idx",
                preprocess_psi(self.kernel_size, self.nlat_in_local, ker_idx, row_idx, col_idx, conv_vals).contiguous(),
                persistent=False,
            )

        # non-persistent buffers follow the module across devices but are not saved
        for buffer_name, buffer in (
            ("psi_ker_idx", ker_idx),
            ("psi_row_idx", row_idx),
            ("psi_col_idx", col_idx),
            ("psi_vals", conv_vals),
        ):
            self.register_buffer(buffer_name, buffer, persistent=False)

        # semi-transposed psi as consumed by the PyTorch fallback path in forward
        self.psi_st = _get_psi(
            self.kernel_size,
            self.psi_idx,
            self.psi_vals,
            self.nlat_in,
            self.nlon_in,
            self.nlat_out,
            self.nlon_out,
            self.nlat_in_local,
            self.nlat_out_local,
            semi_transposed=True,
        )

    def extra_repr(self):
        r"""
        Summarize the layer configuration for the module's printed representation.
        """
        fields = [
            f"in_shape={(self.nlat_in, self.nlon_in)}",
            f"out_shape={(self.nlat_out, self.nlon_out)}",
            f"in_chans={self.groupsize * self.groups}",
            f"out_chans={self.weight.shape[0]}",
            f"filter_basis={self.filter_basis}",
            f"kernel_shape={self.kernel_shape}",
            f"groups={self.groups}",
        ]
        return ", ".join(fields)

    @property
    def psi_idx(self):
        # re-assemble the index tensor with rows (kernel, row, column)
        return torch.stack((self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx), dim=0).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the distributed transpose convolution.

        ``x`` is unpacked as (batch, channels, height, width); the local slab layout is
        presumably (batch, channels, local nlat_in, local nlon_in) — TODO confirm. The
        result has the bias (if any) added.
        """
        batch, _, height, width = x.shape
        x = x.reshape(batch, self.groups, self.groupsize, height, width)

        # weights are applied first, producing one output channel block per kernel tap
        weight = self.weight.reshape(self.groups, -1, self.weight.shape[1], self.weight.shape[2])
        x = torch.einsum("bgcxy,gock->bgokxy", x, weight).contiguous()
        x = x.reshape(batch, -1, x.shape[-3], height, width)
        num_chans = x.shape[1]

        # make lon local by moving the split into the channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_in_shapes)

        # gather the full latitude range and set up the backward reduction
        x = copy_to_polar_region(gather_from_polar_region(x, -2, self.lat_in_shapes))

        use_cuda_kernel = x.is_cuda and _cuda_extension_available
        if use_cuda_kernel:
            out = _disco_s2_transpose_contraction_cuda(
                x, self.psi_roff_idx, self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx, self.psi_vals, self.kernel_size, self.nlat_out_local, self.nlon_out
            )
        else:
            if x.is_cuda:
                warn("couldn't find CUDA extension, falling back to slow PyTorch implementation")
            out = _disco_s2_transpose_contraction_torch(x, self.psi_st.to(x.device), self.nlon_out)

        # transpose back so that lon is split again and channels are local
        if self.comm_size_azimuth > 1:
            out = distributed_transpose_azimuth.apply(out, (-1, 1), compute_split_shapes(num_chans, self.comm_size_azimuth))

        if self.bias is not None:
            out = out + self.bias.reshape(1, -1, 1, 1)

        return out