# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import torch
import torch.nn as nn
import torch.fft
import torch.nn.functional as F

from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim


class DistributedRealSHT(nn.Module):
    """
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
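
    Example (illustrative sketch only, not part of the API): construct the transform and
    apply it to a locally held (batch, channel, lat, lon) tensor. This assumes the polar
    and azimuth process groups of torch_harmonics.distributed have been set up beforehand;
    in a single-process run the group sizes are typically 1 and the module behaves like
    the serial SHT. All shapes below are placeholders.

        sht = DistributedRealSHT(nlat=64, nlon=128, grid="equiangular")
        x = torch.randn(4, 3, 64, 128)   # (batch, channels, local nlat, local nlon)
        coeffs = sht(x)                  # complex coefficients on the local (l, m) shard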
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Distributed SHT layer. Expects the last three dimensions of the input tensor to be channels, latitude, longitude.

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: maximum harmonic degree; if not specified, it is inferred from nlat and the grid
        mmax: maximum harmonic order; if not specified, it defaults to nlon // 2 + 1
        grid: grid in the latitude direction (for now only tensor product grids are supported)
        norm: normalization convention of the spherical harmonics
        csphase: whether to include the Condon-Shortley phase
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)
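        # note (illustrative): compute_split_shapes partitions a dimension as evenly as
        # possible across the ranks of a communicator group, e.g. 64 latitude points over
        # 3 polar ranks would typically yield local chunks such as [22, 21, 21]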

        # combine the quadrature weights with the associated Legendre polynomials
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum('mlk,k->mlk', pct, weights)

        # split weights
        weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # remember the number of channels so we can compute channel splits for the transposes
        num_chans = x.shape[-3]

        # h and w are split. First we make w local by transposing it into the channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature
        x = torch.view_as_real(x)

        # contraction
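        # the einsum below evaluates the quadrature sum over the latitude nodes k,
        # f_hat[..., l, m] = sum_k weights[m, l, k] * f[..., k, m],
        # applied separately to the real and imaginary parts in the trailing dimension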
        xs = torch.einsum('...kmr,mlk->...lmr', x, self.weights.to(x.dtype)).contiguous()

        # cast to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        return x


class DistributedInverseRealSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    nlat, nlon: Output dimensions
    lmax, mmax: Input dimensions (spherical coefficients). If not specified, these are inferred from the output dimensions.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
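
    Example (illustrative sketch only): a forward/inverse round trip with matching
    parameters, assuming the distributed process groups have been initialized (or default
    to size 1 in a single-process run). All shapes below are placeholders.

        sht = DistributedRealSHT(nlat=64, nlon=128, grid="equiangular")
        isht = DistributedInverseRealSHT(nlat=64, nlon=128, grid="equiangular")
        x = torch.randn(4, 3, 64, 128)
        x_rec = isht(sht(x))             # approximately recovers x for band-limited signals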
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute Legendre polynomials
        pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m
        pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # register
        self.register_buffer('pct', pct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # remember the number of channels so we can compute channel splits for the transposes
        num_chans = x.shape[-3]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.l_shapes)

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # einsum
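        # the einsum below evaluates the synthesis sum over the degrees l,
        # f[..., k, m] = sum_l pct[m, l, k] * f_hat[..., l, m],
        # applied separately to the real and imaginary parts in the trailing dimension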
        xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()
        #rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
        #im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
        #xs = torch.stack((rl, im), -1).contiguous()

        # cast back to complex
        x = torch.view_as_complex(xs)

        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, m is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        return x


class DistributedRealVectorSHT(nn.Module):
    """
    Defines a module for computing the forward (real) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
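
    Example (illustrative sketch only): the vector transform expects the two vector
    components stacked in the dimension right before (lat, lon), i.e. an input of shape
    (..., channels, 2, nlat, nlon). This assumes an initialized (or trivial, size-1)
    process grid; all shapes below are placeholders.

        vsht = DistributedRealVectorSHT(nlat=64, nlon=128, grid="equiangular")
        vec = torch.randn(4, 3, 2, 64, 128)   # (batch, channels, 2 components, nlat, nlon)
        coeffs = vsht(vec)                    # complex spheroidal/toroidal coefficients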
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        grid: type of grid the data lives on
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute weights
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine the integration weights and the normalization factor into one tensor:
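        # note: the 1/(l*(l+1)) factor reflects the normalization of the vector spherical
        # harmonic basis (the gradient of Y_lm has squared norm l*(l+1)); the singular l=0
        # entry is overwritten with 1 below to avoid a division by zero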
        l = torch.arange(0, self.lmax)
        norm_factor = 1. / l / (l+1)
        norm_factor[0] = 1.
        weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # we need to split the weights along m:
        weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):

        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # remember the number of channels so we can compute channel splits for the transposes
        num_chans = x.shape[-4]

        # h and w are split. First we make w local by transposing it into the channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature
        x = torch.view_as_real(x)

        # create the output array with lmax modes along the former latitude dimension
        out_shape = list(x.shape)
        out_shape[-3] = self.lmax
        xs = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # contraction - spheroidal component
        # real component
        xs[..., 0, :, :, 0] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], self.weights[0].to(xs.dtype)) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], self.weights[1].to(xs.dtype))
        # imag component
        xs[..., 0, :, :, 1] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], self.weights[0].to(xs.dtype)) \
                              + torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], self.weights[1].to(xs.dtype))

        # contraction - toroidal component
        # real component
        xs[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], self.weights[1].to(xs.dtype)) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], self.weights[0].to(xs.dtype))
        # imag component
        xs[..., 1, :, :, 1] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], self.weights[1].to(xs.dtype)) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], self.weights[0].to(xs.dtype))

        # cast back to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        return x


class DistributedInverseRealVectorSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
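
    Example (illustrative sketch only): a forward/inverse vector round trip, assuming an
    initialized (or trivial, size-1) process grid and matching parameters. All shapes
    below are placeholders.

        vsht = DistributedRealVectorSHT(nlat=64, nlon=128, grid="equiangular")
        ivsht = DistributedInverseRealVectorSHT(nlat=64, nlon=128, grid="equiangular")
        vec = torch.randn(4, 3, 2, 64, 128)
        vec_rec = ivsht(vsht(vec))            # approximately recovers vec for band-limited fields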
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute Legendre polynomials
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m
        dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # register buffer
        self.register_buffer('dpct', dpct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):

        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # store num channels
        num_chans = x.shape[-4]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.l_shapes)

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # contraction - spheroidal component
        # real component
        srl =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1].to(x.dtype))
        # imag component
        sim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) \
              + torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1].to(x.dtype))

        # contraction - toroidal component
        # real component
        trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0].to(x.dtype))
        # imag component
        tim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0].to(x.dtype))

        # reassemble
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # convert to complex
        x = torch.view_as_complex(xs)

        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, m is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        return x