# distributed_sht.py
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os

import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F

from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim


class DistributedRealSHT(nn.Module):
    """
    Defines a module for computing the forward (real-valued) SHT in a distributed setting.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input; the dimension in front of them
    is treated as the channel dimension and is used to redistribute data between ranks.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Distributed SHT layer. Expects the last 3 dimensions of the input tensor to be channels, latitude, longitude.

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: number of retained degrees l; when not given, nlat (nlat - 1 for the "lobatto" grid)
        mmax: number of retained orders m; when not given, nlon // 2 + 1
        grid: grid in the latitude direction, one of "legendre-gauss", "lobatto" or "equiangular"
              (for now only tensor product grids are supported)
        norm: normalization mode, forwarded to the Legendre polynomial precomputation
        csphase: Condon-Shortley phase flag, forwarded to the Legendre polynomial precomputation

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points and weights on the chosen latitude grid
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute how each distributed dimension is split across the ranks of its group
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # combine quadrature weights with the Legendre polynomials: shape (mmax, lmax, nlat)
        weights = torch.from_numpy(w)
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        pct = torch.from_numpy(pct)
        weights = torch.einsum('mlk,k->mlk', pct, weights)

        # keep only the orders m owned by this azimuth rank
        weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the forward SHT to a real tensor of shape (..., channels, lat, lon).

        On entry, latitude is split across the polar group and longitude across the azimuth group.
        On exit, the complex result has degree l split across the polar group, order m split across
        the azimuth group and the channel dimension local.

        Raises:
        ValueError: if x has fewer than 3 dimensions
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # remember the global channel count; it defines how channels are split in the transposes below
        num_chans = x.shape[-3]

        # h and w are split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # keep only the first mmax orders
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature on real and imaginary parts (trailing dim of size 2)
        x = torch.view_as_real(x)
        xs = torch.einsum('...kmr,mlk->...lmr', x, self.weights.to(x.dtype)).contiguous()

        # cast back to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        return x


class DistributedInverseRealSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) SHT in a distributed setting.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    nlat, nlon: Output dimensions
    lmax, mmax: Input dimensions (spherical coefficients). When not given, these are inferred from the output dimensions

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Distributed inverse SHT layer. Expects the last 3 dimensions of the input tensor to be channels, l, m.

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: number of input degrees l; when not given, nlat (nlat - 1 for the "lobatto" grid)
        mmax: number of input orders m; when not given, nlon // 2 + 1
        grid: grid in the latitude direction, one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode, forwarded to the Legendre polynomial precomputation
        csphase: Condon-Shortley phase flag, forwarded to the Legendre polynomial precomputation

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (only the nodes are needed for the inverse transform)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute how each distributed dimension is split across the ranks of its group
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute Legendre polynomials on the output nodes: shape (mmax, lmax, nlat)
        pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
        pct = torch.from_numpy(pct)

        # keep only the orders m owned by this azimuth rank
        pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]

        # register
        self.register_buffer('pct', pct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse SHT to a complex tensor of shape (..., channels, l, m).

        On entry, l is split across the polar group, m across the azimuth group and channels are local.
        On exit, the real result has latitude split across the polar group, longitude split across
        the azimuth group and channels local.

        Raises:
        ValueError: if x has fewer than 3 dimensions
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # remember the global channel count; it defines how channels are split in the transposes below
        num_chans = x.shape[-3]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.l_shapes)

        # evaluate the associated Legendre functions on the output nodes,
        # treating real and imaginary parts as a trailing dimension of size 2
        x = torch.view_as_real(x)
        xs = torch.einsum('...lmr,mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()
        x = torch.view_as_complex(xs)

        # transpose: after this, latitude is split and channels are local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, longitude is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        return x


class DistributedRealVectorSHT(nn.Module):
    """
    Defines a module for computing the forward (real) vector SHT in a distributed setting.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights.
        Expects the last 4 dimensions of the input tensor to be channels, vector component
        (two components, indexed 0 and 1), latitude, longitude.

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: number of retained degrees l; when not given, nlat (nlat - 1 for the "lobatto" grid)
        mmax: number of retained orders m; when not given, nlon // 2 + 1
        grid: type of grid the data lives on, one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode, forwarded to the Legendre polynomial precomputation
        csphase: Condon-Shortley phase flag, forwarded to the Legendre polynomial precomputation

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points and weights on the chosen latitude grid
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute how each distributed dimension is split across the ranks of its group
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute weights
        weights = torch.from_numpy(w)
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        dpct = torch.from_numpy(dpct)

        # combine integration weights and the 1 / (l (l + 1)) normalization factor into one;
        # the l = 0 entry (a division by zero) is overwritten with 1
        degrees = torch.arange(0, self.lmax)
        norm_factor = 1. / degrees / (degrees + 1)
        norm_factor[0] = 1.
        weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # keep only the orders m (dim 1) owned by this azimuth rank
        weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the forward vector SHT to a real tensor of shape (..., channels, 2, lat, lon).

        On entry, latitude is split across the polar group and longitude across the azimuth group.
        On exit, the complex result of shape (..., channels, 2, l, m) holds the spheroidal (index 0)
        and toroidal (index 1) coefficients, with l split across the polar group, m split across
        the azimuth group and channels local.

        Raises:
        ValueError: if x has fewer than 4 dimensions
        """

        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # remember the global channel count; it defines how channels are split in the transposes below
        num_chans = x.shape[-4]

        # h and w are split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # keep only the first mmax orders
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature on real and imaginary parts (trailing dim of size 2)
        x = torch.view_as_real(x)
        w = self.weights.to(x.dtype)

        def contract(inp, wgt):
            # quadrature over latitude k, producing degree l
            return torch.einsum('...km,mlk->...lm', inp, wgt)

        # contraction - spheroidal component (real, imag)
        srl = contract(x[..., 0, :, :, 0], w[0]) - contract(x[..., 1, :, :, 1], w[1])
        sim = contract(x[..., 0, :, :, 1], w[0]) + contract(x[..., 1, :, :, 0], w[1])

        # contraction - toroidal component (real, imag)
        trl = - contract(x[..., 0, :, :, 1], w[1]) - contract(x[..., 1, :, :, 0], w[0])
        tim = contract(x[..., 0, :, :, 0], w[1]) - contract(x[..., 1, :, :, 1], w[0])

        # reassemble: the degree axis comes from the contraction (lmax entries), not from the input's
        # nlat, so this also works when lmax != nlat (e.g. the default "lobatto" configuration)
        spheroidal = torch.stack((srl, sim), -1)
        toroidal = torch.stack((trl, tim), -1)
        xs = torch.stack((spheroidal, toroidal), -4)

        # cast back to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        return x


class DistributedInverseRealVectorSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) vector SHT in a distributed setting.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Distributed inverse vector SHT layer. Expects the last 4 dimensions of the input tensor to be
        channels, component (spheroidal at index 0, toroidal at index 1), l, m.

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: number of input degrees l; when not given, nlat (nlat - 1 for the "lobatto" grid)
        mmax: number of input orders m; when not given, nlon // 2 + 1
        grid: grid in the latitude direction, one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode, forwarded to the Legendre polynomial precomputation
        csphase: Condon-Shortley phase flag, forwarded to the Legendre polynomial precomputation

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (only the nodes are needed for the inverse transform)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute how each distributed dimension is split across the ranks of its group
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute derivatives of the Legendre polynomials on the output nodes
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
        dpct = torch.from_numpy(dpct)

        # keep only the orders m (dim 1) owned by this azimuth rank
        dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]

        # register buffer
        self.register_buffer('dpct', dpct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse vector SHT to a complex tensor of shape (..., channels, 2, l, m).

        On entry, l is split across the polar group, m across the azimuth group and channels are local.
        On exit, the real result of shape (..., channels, 2, lat, lon) has latitude split across the
        polar group, longitude split across the azimuth group and channels local.

        Raises:
        ValueError: if x has fewer than 4 dimensions
        """

        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # remember the global channel count; it defines how channels are split in the transposes below
        num_chans = x.shape[-4]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.l_shapes)

        # evaluate the associated Legendre functions on the output nodes,
        # treating real and imaginary parts as a trailing dimension of size 2
        x = torch.view_as_real(x)
        dpct = self.dpct.to(x.dtype)

        def contract(inp, wgt):
            # sum over degree l, producing latitude k
            return torch.einsum('...lm,mlk->...km', inp, wgt)

        # contraction - spheroidal component (real, imag)
        srl = contract(x[..., 0, :, :, 0], dpct[0]) - contract(x[..., 1, :, :, 1], dpct[1])
        sim = contract(x[..., 0, :, :, 1], dpct[0]) + contract(x[..., 1, :, :, 0], dpct[1])

        # contraction - toroidal component (real, imag)
        trl = - contract(x[..., 0, :, :, 1], dpct[1]) - contract(x[..., 1, :, :, 0], dpct[0])
        tim = contract(x[..., 0, :, :, 0], dpct[1]) - contract(x[..., 1, :, :, 1], dpct[0])

        # reassemble into (..., channels, 2, lat, m, 2)
        spheroidal = torch.stack((srl, sim), -1)
        toroidal = torch.stack((trl, tim), -1)
        xs = torch.stack((spheroidal, toroidal), -4)

        # convert to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, latitude is split and channels are local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, longitude is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        return x