# distributed_sht.py
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import torch
import torch.nn as nn
import torch.fft
import torch.nn.functional as F

from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim


class DistributedRealSHT(nn.Module):
    """
    Defines a module for computing the forward (real-valued) SHT in a distributed setting.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input. The third-to-last (channel)
    dimension is used as the exchange dimension of the distributed transposes, so the input
    must have at least three dimensions.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree, by default nlat for the "legendre-gauss" and
        "equiangular" grids and nlat - 1 for the "lobatto" grid
    mmax: int
        Maximum spherical harmonic order, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Tensor of shape (..., lmax, mmax). When distributed, l is split across the polar
        process group and m is split across the azimuth process group.

    Raises
    ------
    ValueError
        If ``grid`` is not one of the supported quadrature modes.

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points and weights; the default lmax depends on the grid
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits: used by the distributed transposes in forward
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # combine quadrature weights with the legendre weights
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum('mlk,k->mlk', pct, weights)

        # split weights along m (dim 0) so every azimuth rank only keeps its local orders
        weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the forward SHT.

        Parameters
        ----------
        x: torch.Tensor
            Real tensor of shape (..., channels, nlat, nlon). When distributed, the latitude
            and longitude dimensions are split across the polar and azimuth groups.

        Returns
        -------
        torch.Tensor
            Complex tensor of shape (..., channels, lmax, mmax).

        Raises
        ------
        ValueError
            If ``x`` has fewer than three dimensions.
        """
        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # the channel dimension is used for the distributed transposes
        num_chans = x.shape[-3]

        # h and w is split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate to the requested number of orders
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature on the real/imag parts separately
        x = torch.view_as_real(x)

        # contraction over the latitude nodes k
        xs = torch.einsum('...kmr,mlk->...lmr', x, self.weights.to(x.dtype)).contiguous()

        # cast to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        return x


class DistributedInverseRealSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) SHT in a distributed setting.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The inverse SHT is applied to the last two (spectral) dimensions of the input. The
    third-to-last (channel) dimension is used as the exchange dimension of the distributed
    transposes, so the input must have at least three dimensions.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree, by default nlat for the "legendre-gauss" and
        "equiangular" grids and nlat - 1 for the "lobatto" grid
    mmax: int
        Maximum spherical harmonic order, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Real tensor of shape (..., nlat, nlon) on the spatial grid.

    Raises
    ------
    ValueError
        If ``grid`` is not one of the supported quadrature modes.

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (the weights are not needed for the inverse)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits: used by the distributed transposes in forward
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute legendre polynomials
        pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m (dim 0) so every azimuth rank only keeps its local orders
        pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # register
        self.register_buffer('pct', pct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse SHT.

        Parameters
        ----------
        x: torch.Tensor
            Complex tensor of shape (..., channels, lmax, mmax). When distributed, l is split
            across the polar group and m across the azimuth group.

        Returns
        -------
        torch.Tensor
            Real tensor of shape (..., channels, nlat, nlon).

        Raises
        ------
        ValueError
            If ``x`` has fewer than three dimensions.
        """
        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # the channel dimension is used for the distributed transposes
        num_chans = x.shape[-3]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.l_shapes)

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # contraction over the degrees l
        xs = torch.einsum('...lmr,mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()

        # cast back to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, latitude is split and channels are local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)

        # set DCT and nyquist frequencies to 0:
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
            x[..., self.nlon // 2].imag = 0.0

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, longitude is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        return x


class DistributedRealVectorSHT(nn.Module):
    """
    Defines a module for computing the forward (real) vector SHT in a distributed setting.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input (vector component, latitude,
    longitude). The fourth-to-last (channel) dimension is used as the exchange dimension of
    the distributed transposes, so the input must have at least four dimensions.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Number of spherical harmonic degrees (degrees 0, ..., lmax - 1), by default nlat for
        the "legendre-gauss" and "equiangular" grids and nlat - 1 for the "lobatto" grid
    mmax: int
        Maximum spherical harmonic order, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Tensor of shape (..., 2, lmax, mmax) holding the spheroidal (index 0) and toroidal
        (index 1) coefficients. When distributed, l is split across the polar process group
        and m is split across the azimuth process group.

    Raises
    ------
    ValueError
        If ``grid`` is not one of the supported quadrature modes.

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points and weights; the default lmax depends on the grid
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits: used by the distributed transposes in forward
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute weights
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine integration weights, normalization factor in to one:
        # 1 / (l (l + 1)) is undefined for l = 0, so that entry is set to 1
        l = torch.arange(0, self.lmax)
        norm_factor = 1. / l / (l+1)
        norm_factor[0] = 1.
        weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # split along m (dim 1) so every azimuth rank only keeps its local orders
        weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the forward vector SHT.

        Parameters
        ----------
        x: torch.Tensor
            Real tensor of shape (..., channels, 2, nlat, nlon). When distributed, latitude
            and longitude are split across the polar and azimuth groups.

        Returns
        -------
        torch.Tensor
            Complex tensor of shape (..., channels, 2, lmax, mmax).

        Raises
        ------
        ValueError
            If ``x`` has fewer than four dimensions.
        """
        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # the channel dimension is used for the distributed transposes
        num_chans = x.shape[-4]

        # h and w is split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate to the requested number of orders
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature on the real/imag parts separately
        x = torch.view_as_real(x)

        # create output array: same layout as x, except that the latitude dimension (nlat)
        # is replaced by the degree dimension (lmax); zeros_like(x) would only fit lmax == nlat
        out_shape = list(x.shape)
        out_shape[-3] = self.lmax
        xs = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        w0 = self.weights[0].to(xs.dtype)
        w1 = self.weights[1].to(xs.dtype)

        # contraction - spheroidal component
        # real component
        xs[..., 0, :, :, 0] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], w0) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], w1)
        # imag component
        xs[..., 0, :, :, 1] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], w0) \
                              + torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], w1)

        # contraction - toroidal component
        # real component
        xs[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], w1) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], w0)
        # imag component
        xs[..., 1, :, :, 1] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], w1) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], w0)

        # cast to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        return x


class DistributedInverseRealVectorSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) vector SHT in a distributed setting.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The inverse SHT is applied to the last three dimensions of the input (spheroidal/toroidal
    component, degree, order). The fourth-to-last (channel) dimension is used as the exchange
    dimension of the distributed transposes, so the input must have at least four dimensions.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree, by default nlat for the "legendre-gauss" and
        "equiangular" grids and nlat - 1 for the "lobatto" grid
    mmax: int
        Maximum spherical harmonic order, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Real tensor of shape (..., 2, nlat, nlon) on the spatial grid.

    Raises
    ------
    ValueError
        If ``grid`` is not one of the supported quadrature modes.

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (the weights are not needed for the inverse)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits: used by the distributed transposes in forward
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute legendre polynomials
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m (dim 1) so every azimuth rank only keeps its local orders
        dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # register buffer
        self.register_buffer('dpct', dpct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse vector SHT.

        Parameters
        ----------
        x: torch.Tensor
            Complex tensor of shape (..., channels, 2, lmax, mmax). When distributed, l is
            split across the polar group and m across the azimuth group.

        Returns
        -------
        torch.Tensor
            Real tensor of shape (..., channels, 2, nlat, nlon).

        Raises
        ------
        ValueError
            If ``x`` has fewer than four dimensions.
        """
        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # store num channels (used for the distributed transposes)
        num_chans = x.shape[-4]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.l_shapes)

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        d0 = self.dpct[0].to(x.dtype)
        d1 = self.dpct[1].to(x.dtype)

        # contraction - spheroidal component
        # real component
        srl =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], d0) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], d1)
        # imag component
        sim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], d0) \
              + torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], d1)

        # contraction - toroidal component
        # real component
        trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], d1) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], d0)
        # imag component
        tim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], d1) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], d0)

        # reassemble: (..., 2, nlat, m, re/im)
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # convert to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, latitude is split and channels are local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)

        # set DCT and nyquist frequencies to zero
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
            x[..., self.nlon // 2].imag = 0.0

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, longitude is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        return x