# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import numpy as np
import torch
import torch.nn as nn
import torch.fft
import torch.nn.functional as F

from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
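

# The distributed transforms below split the latitude/longitude grid and the spherical-harmonic
# coefficients across a 2D (polar x azimuth) process grid. The following helper is a small
# illustrative sketch (not part of the original module) of how the split bookkeeping used
# throughout this file behaves; it only relies on compute_split_shapes, which is imported above.
def _example_split_shapes(nlat: int = 181, nlon: int = 360, polar_size: int = 4, azimuth_size: int = 3):
    """Illustrative only: per-rank chunk sizes produced by compute_split_shapes."""
    lat_shapes = compute_split_shapes(nlat, polar_size)    # chunk sizes per polar rank
    lon_shapes = compute_split_shapes(nlon, azimuth_size)  # chunk sizes per azimuth rank
    # the chunk sizes are expected to partition the respective dimension
    assert sum(lat_shapes) == nlat and sum(lon_shapes) == nlon
    return lat_shapes, lon_shapes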


class DistributedRealSHT(nn.Module):
    """
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
        """
        Initializes the SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        grid: grid in the latitude direction (for now only tensor product grids are supported)
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # convert the quadrature nodes (cos(theta)) to colatitudes and flip them
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # combine quadrature weights with the legendre weights
        weights = torch.from_numpy(w)
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        pct = torch.from_numpy(pct)
        weights = torch.einsum('mlk,k->mlk', pct, weights)

        # split the weights along the m dimension; the local buffer has shape (mmax_local, lmax, nlat)
        weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):

        # number of locally present channels; needed to compute the split shapes after the transposes
        num_chans = x.shape[1]

        # h and w are split. First we make w local by transposing it into the channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_shapes)

        # apply the real FFT in the longitudinal direction: make sure to truncate to nlon.
        # the forward-normalized FFT divides by nlon, so the 2*pi factor turns the sum over
        # longitudes into the quadrature of the integral over the azimuth angle
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate to the first mmax modes
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (1, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature: view the complex tensor as real/imaginary pairs
        x = torch.view_as_real(x)

        # contraction over the latitude dimension (k: latitude, l: degree, m: order, r: real/imag)
        xs = torch.einsum('...kmr,mlk->...lmr', x, self.weights.to(x.dtype)).contiguous()

        # cast back to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)

        return x
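

# Illustrative sketch (not part of the original module): given the global transform sizes and the
# coordinates of a rank in the (polar, azimuth) process grid, compute the local shapes that
# DistributedRealSHT expects on input and produces on output. Only compute_split_shapes, which is
# used by the module itself, is assumed here.
def _example_forward_sht_local_shapes(nlat, nlon, lmax, mmax, polar_size, azimuth_size, polar_rank, azimuth_rank):
    """Illustrative only: local (lat, lon) input shape and local (l, m) output shape per rank."""
    # spatial input: latitude split across the polar group, longitude split across the azimuth group
    in_shape = (compute_split_shapes(nlat, polar_size)[polar_rank],
                compute_split_shapes(nlon, azimuth_size)[azimuth_rank])
    # spectral output: degree l split across the polar group, order m split across the azimuth group
    out_shape = (compute_split_shapes(lmax, polar_size)[polar_rank],
                 compute_split_shapes(mmax, azimuth_size)[azimuth_rank])
    return in_shape, out_shape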


class DistributedInverseRealSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    nlat, nlon: Output dimensions
    lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # convert the quadrature nodes (cos(theta)) to colatitudes and flip them
        t = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute the associated Legendre polynomials
        pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
        pct = torch.from_numpy(pct)

        # split in m
        pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]

        # register
        self.register_buffer('pct', pct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):

        # number of locally present channels; needed to compute the split shapes after the transposes
        num_chans = x.shape[1]

        # transpose: after this, channels are split and l is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (1, -2), self.l_shapes)

        # evaluate the associated Legendre functions on the output nodes: view the complex tensor as real/imaginary pairs
        x = torch.view_as_real(x)

        # contraction over the degree dimension (l: degree, m: order, k: latitude, r: real/imag)
        xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()
        #rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
        #im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
        #xs = torch.stack((rl, im), -1).contiguous()

        # cast back to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, channels are local again and the latitude dimension h is split
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.m_shapes)

        # apply the inverse (real) FFT in the longitudinal direction
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, the longitude dimension w is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)

        return x
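

# Illustrative round-trip sketch (not part of the original module): run the forward and inverse
# scalar transforms back to back. This assumes that the torch_harmonics.distributed helpers
# report polar and azimuth group sizes of 1 (single-rank setup), so that all distributed
# transposes are skipped; with larger process groups, torch.distributed and the
# torch_harmonics.distributed process groups must be initialized first and the input must be
# pre-split as described in the docstrings above.
def _example_scalar_roundtrip(nlat: int = 64, nlon: int = 128, batch: int = 2, chans: int = 3):
    """Illustrative only: grid -> spectrum -> grid round trip on a single rank."""
    sht = DistributedRealSHT(nlat, nlon, grid="equiangular")
    isht = DistributedInverseRealSHT(nlat, nlon, grid="equiangular")
    x = torch.randn(batch, chans, nlat, nlon, dtype=torch.float64)
    coeffs = sht(x)    # complex coefficients, locally of shape [batch, chans, lmax, mmax]
    y = isht(coeffs)   # back to the grid, locally of shape [batch, chans, nlat, nlon]
    return x.shape == y.shape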


class DistributedRealVectorSHT(nn.Module):
    """
    Defines a module for computing the forward (real) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
        """
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        grid: type of grid the data lives on
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # convert the quadrature nodes (cos(theta)) to colatitudes and flip them
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute weights
        weights = torch.from_numpy(w)
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        dpct = torch.from_numpy(dpct)

        # combine the integration weights and the 1/(l*(l+1)) normalization of the vector
        # spherical harmonics into a single factor
        l = torch.arange(0, self.lmax)
        norm_factor = 1. / l / (l+1)
        norm_factor[0] = 1.
        weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # split the weights along the m dimension
        weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):

        # the vector transform acts on the last three dimensions (2, nlat, nlon)
        assert len(x.shape) >= 3

        # number of locally present channels; needed to compute the split shapes after the transposes
        num_chans = x.shape[1]

        # h and w are split. First we make w local by transposing it into the channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_shapes)

        # apply the real FFT in the longitudinal direction: make sure to truncate to nlon.
        # the forward-normalized FFT divides by nlon, so the 2*pi factor turns the sum over
        # longitudes into the quadrature of the integral over the azimuth angle
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate to the first mmax modes
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (1, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature: view the complex tensor as real/imaginary pairs
        x = torch.view_as_real(x)

        # create the output array: the latitude dimension (size nlat) is replaced by the degree dimension (size lmax)
        out_shape = list(x.shape)
        out_shape[-3] = self.lmax
        xs = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # contraction - spheroidal component
        # real component
        xs[..., 0, :, :, 0] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], self.weights[0].to(xs.dtype)) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], self.weights[1].to(xs.dtype))
        # imag component
        xs[..., 0, :, :, 1] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], self.weights[0].to(xs.dtype)) \
                              + torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], self.weights[1].to(xs.dtype))

        # contraction - toroidal component
        # real component
        xs[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], self.weights[1].to(xs.dtype)) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], self.weights[0].to(xs.dtype))
        # imag component
        xs[..., 1, :, :, 1] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], self.weights[1].to(xs.dtype)) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], self.weights[0].to(xs.dtype))

        # cast back to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)

        return x
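

# Illustrative sketch (not part of the original module): local tensor layouts for the forward
# vector transform. The vector field occupies the last three dimensions as
# (2, nlat_local, nlon_local); the transform returns complex coefficients whose last three
# dimensions are (2, l_local, m_local). Only compute_split_shapes, used by the module itself,
# is assumed here.
def _example_vector_sht_local_shapes(nlat, nlon, lmax, mmax, polar_size, azimuth_size, polar_rank, azimuth_rank):
    """Illustrative only: local input/output shapes of DistributedRealVectorSHT per rank."""
    in_shape = (2,
                compute_split_shapes(nlat, polar_size)[polar_rank],
                compute_split_shapes(nlon, azimuth_size)[azimuth_rank])
    out_shape = (2,
                 compute_split_shapes(lmax, polar_size)[polar_rank],
                 compute_split_shapes(mmax, azimuth_size)[azimuth_rank])
    return in_shape, out_shape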


class DistributedInverseRealVectorSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # convert the quadrature nodes (cos(theta)) to colatitudes and flip them
        t = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute the associated Legendre polynomials
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
        dpct = torch.from_numpy(dpct)

        # split in m
        dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]

        # register buffer
        self.register_buffer('dpct', dpct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):

        # number of locally present channels; needed to compute the split shapes after the transposes
        num_chans = x.shape[1]

        # transpose: after this, channels are split and l is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (1, -2), self.l_shapes)

        # evaluate the associated Legendre functions on the output nodes: view the complex tensor as real/imaginary pairs
        x = torch.view_as_real(x)

        # contraction - spheroidal component
        # real component
        srl =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1].to(x.dtype))
        # imag component
        sim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) \
              + torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1].to(x.dtype))

        # contraction - toroidal component
        # real component
        trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0].to(x.dtype))
        # imag component
        tim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0].to(x.dtype))

        # reassemble the real and imaginary parts of the two components
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # convert to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, channels are local again and the latitude dimension h is split
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (1, -1), self.m_shapes)

        # apply the inverse (real) FFT in the longitudinal direction
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, the longitude dimension w is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)

        return x
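

# Illustrative sketch (not part of the original module): compose the forward and inverse vector
# transforms. As in the scalar sketch above, this assumes that the torch_harmonics.distributed
# helpers report polar and azimuth group sizes of 1 (single-rank setup); with larger process
# groups the distributed transposes require properly initialized process groups and pre-split
# inputs.
def _example_vector_roundtrip(nlat: int = 64, nlon: int = 128, batch: int = 2, chans: int = 3):
    """Illustrative only: vector field -> (spheroidal, toroidal) spectrum -> vector field."""
    vsht = DistributedRealVectorSHT(nlat, nlon, grid="equiangular")
    ivsht = DistributedInverseRealVectorSHT(nlat, nlon, grid="equiangular")
    # vector field with its two components in dim -3
    u = torch.randn(batch, chans, 2, nlat, nlon, dtype=torch.float64)
    coeffs = vsht(u)    # complex coefficients, locally [batch, chans, 2, lmax, mmax]
    v = ivsht(coeffs)   # back to the grid, locally [batch, chans, 2, nlat, nlon]
    return u.shape == v.shape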