# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import torch
import torch.nn as nn
import torch.fft
import torch.nn.functional as F

from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim


class DistributedRealSHT(nn.Module):
    """
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Distributed SHT layer. Expects the last 3 dimensions of the input tensor to be channels, latitude, longitude.

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: size of the l dimension of the output; when falsy (None or 0) it defaults to
              nlat (nlat-1 on the "lobatto" grid)
        mmax: size of the m dimension of the output; when falsy (None or 0) it defaults to nlon // 2 + 1
        grid: grid in the latitude direction (for now only tensor product grids are supported);
              one of "legendre-gauss", "lobatto", "equiangular"
        norm: forwarded unchanged to _precompute_legpoly
        csphase: forwarded unchanged to _precompute_legpoly

        Raises:
        ValueError: if grid is not one of the supported names
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # combine quadrature weights with the legendre weights
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum('mlk,k->mlk', pct, weights)

        # split weights: every azimuth rank only keeps its own chunk along the m dimension (dim 0)
        weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # remember quadrature weights (non-persistent: they are recomputed on construction)
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        """
        Applies the forward transform to the last two dimensions of x.

        Parameters:
        x: real tensor with at least 3 dimensions; the last three are treated as
           channels, latitude, longitude

        Returns:
        complex tensor whose last two dimensions are indexed by l and m

        Raises:
        ValueError: if x has fewer than 3 dimensions
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # remember the channel count: it is needed to undo the distributed transposes
        num_chans = x.shape[-3]

        # h and w is split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature: the real weights are applied to the
        # real and imaginary parts alike, hence the trailing axis r of size 2
        x = torch.view_as_real(x)

        # contraction
        xs = torch.einsum('...kmr,mlk->...lmr', x, self.weights.to(x.dtype)).contiguous()

        # cast to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        return x


class DistributedInverseRealSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    nlat, nlon: Output dimensions
    lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Initializes the inverse SHT layer, precomputing the Legendre polynomials on the output nodes

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: size of the l dimension of the input; when falsy (None or 0) it defaults to
              nlat (nlat-1 on the "lobatto" grid)
        mmax: size of the m dimension of the input; when falsy (None or 0) it defaults to nlon // 2 + 1
        grid: one of "legendre-gauss", "lobatto", "equiangular"
        norm: forwarded unchanged to _precompute_legpoly
        csphase: forwarded unchanged to _precompute_legpoly

        Raises:
        ValueError: if grid is not one of the supported names
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (only the nodes are needed, the weights are discarded)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute legendre polynomials
        pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m: every azimuth rank only keeps its own chunk along dim 0
        pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # register (non-persistent: the polynomials are recomputed on construction)
        self.register_buffer('pct', pct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        """
        Applies the inverse transform to the last two dimensions of x.

        Parameters:
        x: complex tensor with at least 3 dimensions; the last three are treated as
           channels, l, m

        Returns:
        real tensor whose last dimension has size nlon

        Raises:
        ValueError: if x has fewer than 3 dimensions
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # remember the channel count: it is needed to undo the distributed transposes
        num_chans = x.shape[-3]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.l_shapes)

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # einsum
        xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()

        # back to complex, as required by the inverse FFT below
        x = torch.view_as_complex(xs)

        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)

        # set the imaginary parts of the DC and nyquist frequencies to 0
        # (the nyquist bin only exists for even nlon and may have been truncated away)
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
            x[..., self.nlon // 2].imag = 0.0

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, m is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        return x


class DistributedRealVectorSHT(nn.Module):
    """
    Defines a module for computing the forward (real) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: size of the l dimension of the output; when falsy (None or 0) it defaults to
              nlat (nlat-1 on the "lobatto" grid)
        mmax: size of the m dimension of the output; when falsy (None or 0) it defaults to nlon // 2 + 1
        grid: type of grid the data lives on; one of "legendre-gauss", "lobatto", "equiangular"
        norm: forwarded unchanged to _precompute_dlegpoly
        csphase: forwarded unchanged to _precompute_dlegpoly

        Raises:
        ValueError: if grid is not one of the supported names
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute weights
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine integration weights, normalization factor in to one:
        l = torch.arange(0, self.lmax)
        norm_factor = 1. / l / (l+1)
        # the l = 0 entry above is a division by zero; it is overwritten with 1
        norm_factor[0] = 1.
        weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # split in m: every azimuth rank only keeps its own chunk along dim 1
        weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # remember quadrature weights (non-persistent: they are recomputed on construction)
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        """
        Applies the forward vector transform to the last three dimensions of x.

        Parameters:
        x: real tensor with at least 4 dimensions; the last four are treated as
           channels, vector component (size 2), latitude, longitude

        Returns:
        complex tensor whose last three dimensions are
        (spheroidal/toroidal component, l, m)

        Raises:
        ValueError: if x has fewer than 4 dimensions
        """

        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # remember the channel count: it is needed to undo the distributed transposes
        num_chans = x.shape[-4]

        # h and w is split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature
        x = torch.view_as_real(x)

        # cast the two weight sets once instead of once per contraction
        w0 = self.weights[0].to(x.dtype)
        w1 = self.weights[1].to(x.dtype)

        # NOTE: the result is assembled with torch.stack rather than written into a
        # zeros_like(x) buffer: the contraction replaces the latitude axis (k) by the
        # l axis, so a buffer shaped like x only fits when lmax == nlat.

        # contraction - spheroidal component
        # real component
        srl =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], w0) \
              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], w1)
        # imag component
        sim =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], w0) \
              + torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], w1)

        # contraction - toroidal component
        # real component
        trl = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], w1) \
              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], w0)
        # imag component
        tim =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], w1) \
              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], w0)

        # reassemble: (real, imag) pairs go last, the two components go in front of (l, m)
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # convert to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        return x


class DistributedInverseRealVectorSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Initializes the inverse vector SHT layer, precomputing the Legendre polynomial derivatives

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: size of the l dimension of the input; when falsy (None or 0) it defaults to
              nlat (nlat-1 on the "lobatto" grid)
        mmax: size of the m dimension of the input; when falsy (None or 0) it defaults to nlon // 2 + 1
        grid: one of "legendre-gauss", "lobatto", "equiangular"
        norm: forwarded unchanged to _precompute_dlegpoly
        csphase: forwarded unchanged to _precompute_dlegpoly

        Raises:
        ValueError: if grid is not one of the supported names
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (only the nodes are needed, the weights are discarded)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute legendre polynomials
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m: every azimuth rank only keeps its own chunk along dim 1
        dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # register buffer (non-persistent: the polynomials are recomputed on construction)
        self.register_buffer('dpct', dpct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        """
        Applies the inverse vector transform to the last three dimensions of x.

        Parameters:
        x: complex tensor with at least 4 dimensions; the last four are treated as
           channels, spheroidal/toroidal component (size 2), l, m

        Returns:
        real tensor whose last dimension has size nlon

        Raises:
        ValueError: if x has fewer than 4 dimensions
        """

        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # store num channels
        num_chans = x.shape[-4]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.l_shapes)

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # cast the two polynomial sets once instead of once per contraction
        d0 = self.dpct[0].to(x.dtype)
        d1 = self.dpct[1].to(x.dtype)

        # contraction - spheroidal component
        # real component
        srl =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], d0) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], d1)
        # imag component
        sim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], d0) \
              + torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], d1)

        # contraction - toroidal component
        # real component
        trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], d1) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], d0)
        # imag component
        tim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], d1) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], d0)

        # reassemble: (real, imag) pairs go last, the two components go in front of (k, m)
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # convert to complex
        x = torch.view_as_complex(xs)

        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)

        # set the imaginary parts of the DC and nyquist frequencies to zero
        # (the nyquist bin only exists for even nlon and may have been truncated away)
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
            x[..., self.nlon // 2].imag = 0.0

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, m is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        return x