distributed_sht.py 25.9 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import os
import torch
import torch.nn as nn
import torch.fft
import torch.nn.functional as F

38
39
from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
Boris Bonev's avatar
Boris Bonev committed
40
41
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
42
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
Boris Bonev's avatar
Boris Bonev committed
43
44
45
46
47
48
49
50


class DistributedRealSHT(nn.Module):
    """
    Defines a module for computing the forward (real-valued) SHT in a distributed setting.
    Precomputes quadrature nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input.

    The input is expected to be distributed over a 2D process grid: latitudes are split
    across the polar group and longitudes across the azimuth group. The spectral output
    is distributed as well: degrees l are split across the polar group and orders m
    across the azimuth group.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree, by default nlat (nlat - 1 for the "lobatto" grid)
    mmax: int
        Maximum spherical harmonic order, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Complex tensor of shape (..., lmax, mmax), distributed as described above

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Distributed SHT layer. Expects the last 3 dimensions of the input tensor to be channels, latitude, longitude.

        Parameters
        ----------
        nlat: int
            Number of latitude points
        nlon: int
            Number of longitude points
        lmax: int
            Maximum spherical harmonic degree, by default nlat (nlat - 1 for the "lobatto" grid)
        mmax: int
            Maximum spherical harmonic order, by default nlon // 2 + 1
        grid: str
            Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
        norm: str
            Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
        csphase: bool
            Whether to apply the Condon-Shortley phase factor, by default True

        Raises
        ------
        ValueError
            If grid is not one of the supported quadrature types.
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points (cosines of the colatitudes) and quadrature weights
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits of the spatial and spectral dimensions over the process grid
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # combine quadrature weights with the legendre weights: (m, l, k)
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum('mlk,k->mlk', pct, weights)

        # split weights in m: each azimuth rank only contracts its local orders
        weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        """
        Compute the distributed forward SHT of the local input shard.

        Parameters
        ----------
        x: torch.Tensor
            Local shard of shape (..., channels, nlat_local, nlon_local), with latitude split
            across the polar group and longitude split across the azimuth group.

        Returns
        -------
        torch.Tensor
            Complex local shard of shape (..., channels, lmax_local, mmax_local), with l split
            across the polar group and m split across the azimuth group.

        Raises
        ------
        ValueError
            If x has fewer than 3 dimensions.
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # remember the (unsplit) number of input channels; needed for the channel split shapes below
        num_chans = x.shape[-3]

        # h and w are split. First we make w local by transposing into the channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature on the real view (last dim holds real/imag parts)
        x = torch.view_as_real(x)

        # contraction over the latitude nodes k
        xs = torch.einsum('...kmr,mlk->...lmr', x, self.weights.to(x.dtype)).contiguous()

        # cast to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        return x
Boris Bonev's avatar
Boris Bonev committed
201
202
203
204
205
206
207


class DistributedInverseRealSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) SHT in a distributed setting.
    Precomputes quadrature nodes and associated Legendre polynomials on these nodes.
    The inverse SHT is applied to the last two dimensions of the input.

    The spectral input is expected to be distributed over a 2D process grid: degrees l are
    split across the polar group and orders m across the azimuth group. The spatial output
    has latitudes split across the polar group and longitudes across the azimuth group.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree, by default nlat (nlat - 1 for the "lobatto" grid)
    mmax: int
        Maximum spherical harmonic order, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Real tensor of shape (..., nlat, nlon), distributed as described above

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Distributed inverse SHT layer. Expects the last 3 dimensions of the input tensor to be channels, degree l, order m.

        Raises
        ------
        ValueError
            If grid is not one of the supported quadrature types.
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (only the nodes are needed for the inverse transform)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits of the spatial and spectral dimensions over the process grid
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute legendre polynomials: (m, l, k)
        pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m: each azimuth rank only synthesizes its local orders
        pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # register
        self.register_buffer('pct', pct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        """
        Compute the distributed inverse SHT of the local spectral shard.

        Parameters
        ----------
        x: torch.Tensor
            Complex local shard of shape (..., channels, lmax_local, mmax_local), with l split
            across the polar group and m split across the azimuth group.

        Returns
        -------
        torch.Tensor
            Real local shard of shape (..., channels, nlat_local, nlon_local), with latitude split
            across the polar group and longitude split across the azimuth group.

        Raises
        ------
        ValueError
            If x has fewer than 3 dimensions.
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # remember the (unsplit) number of input channels; needed for the channel split shapes below
        num_chans = x.shape[-3]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-3, -2), self.l_shapes)

        # Evaluate associated Legendre functions on the output nodes (real view: last dim is real/imag)
        x = torch.view_as_real(x)

        # contraction over the degrees l
        xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()

        # cast back to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, latitude is split and channels are local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -3), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-3, -1), self.m_shapes)

        # set DCT and nyquist frequencies to 0:
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
            x[..., self.nlon // 2].imag = 0.0

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, longitude is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -3), chan_shapes)

        return x
Boris Bonev's avatar
Boris Bonev committed
335
336
337
338
339
340
341
342


class DistributedRealVectorSHT(nn.Module):
    """
    Defines a module for computing the forward (real) vector SHT in a distributed setting.
    Precomputes quadrature nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    The input is expected to be distributed over a 2D process grid: latitudes are split
    across the polar group and longitudes across the azimuth group. The spectral output
    is distributed as well: degrees l are split across the polar group and orders m
    across the azimuth group.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree, by default nlat (nlat - 1 for the "lobatto" grid)
    mmax: int
        Maximum spherical harmonic order, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Complex tensor of shape (..., 2, lmax, mmax), distributed as described above

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights

        Parameters
        ----------
        nlat: int
            Number of latitude points
        nlon: int
            Number of longitude points
        lmax: int
            Maximum spherical harmonic degree, by default nlat (nlat - 1 for the "lobatto" grid)
        mmax: int
            Maximum spherical harmonic order, by default nlon // 2 + 1
        grid: str
            Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
        norm: str
            Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
        csphase: bool
            Whether to apply the Condon-Shortley phase factor, by default True

        Raises
        ------
        ValueError
            If grid is not one of the supported quadrature types.
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (cosines of the colatitudes) and quadrature weights
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits of the spatial and spectral dimensions over the process grid
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute weights: (2, m, l, k)
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine integration weights, normalization factor 1 / (l (l+1)) in to one;
        # l = 0 would divide by zero, so its factor is set to 1
        l = torch.arange(0, self.lmax)
        norm_factor = 1. / l / (l+1)
        norm_factor[0] = 1.
        weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # split in m: each azimuth rank only contracts its local orders
        weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        """
        Compute the distributed forward vector SHT of the local input shard.

        Parameters
        ----------
        x: torch.Tensor
            Local shard of shape (..., channels, 2, nlat_local, nlon_local), with latitude split
            across the polar group and longitude split across the azimuth group.

        Returns
        -------
        torch.Tensor
            Complex local shard of shape (..., channels, 2, lmax_local, mmax_local) holding the
            spheroidal (index 0) and toroidal (index 1) coefficients, with l split across the
            polar group and m split across the azimuth group.

        Raises
        ------
        ValueError
            If x has fewer than 4 dimensions.
        """

        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # remember the (unsplit) number of input channels; needed for the channel split shapes below
        num_chans = x.shape[-4]

        # h and w are split. First we make w local by transposing into the channel dim
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.lon_shapes)

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")

        # truncate
        x = x[..., :self.mmax]

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.lat_shapes)

        # do the Legendre-Gauss quadrature on the real view (last dim holds real/imag parts)
        x = torch.view_as_real(x)

        # create output array: the latitude dim (-3 of the real view) becomes the degree dim,
        # which has lmax entries. Using zeros_like(x) here would only work when lmax == nlat.
        out_shape = list(x.shape)
        out_shape[-3] = self.lmax
        xs = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # contraction - spheroidal component
        # real component
        xs[..., 0, :, :, 0] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], self.weights[0].to(xs.dtype)) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], self.weights[1].to(xs.dtype))
        # imag component
        xs[..., 0, :, :, 1] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], self.weights[0].to(xs.dtype)) \
                              + torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], self.weights[1].to(xs.dtype))

        # contraction - toroidal component
        # real component
        xs[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], self.weights[1].to(xs.dtype)) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], self.weights[0].to(xs.dtype))
        # imag component
        xs[..., 1, :, :, 1] =   torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], self.weights[1].to(xs.dtype)) \
                              - torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], self.weights[0].to(xs.dtype))

        # cast to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        return x
Boris Bonev's avatar
Boris Bonev committed
511
512
513
514
515
516
517


class DistributedInverseRealVectorSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) vector SHT in a distributed setting.
    Precomputes quadrature nodes and associated Legendre polynomials on these nodes.
    The inverse SHT is applied to the last three dimensions of the input.

    The spectral input is expected to be distributed over a 2D process grid: degrees l are
    split across the polar group and orders m across the azimuth group. The spatial output
    has latitudes split across the polar group and longitudes across the azimuth group.

    Parameters
    ----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int
        Maximum spherical harmonic degree, by default nlat (nlat - 1 for the "lobatto" grid)
    mmax: int
        Maximum spherical harmonic order, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Returns
    -------
    x: torch.Tensor
        Real tensor of shape (..., 2, nlat, nlon), distributed as described above

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """
    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        """
        Distributed inverse vector SHT layer. Expects the last 4 dimensions of the input tensor
        to be channels, vector component (spheroidal/toroidal), degree l, order m.

        Raises
        ------
        ValueError
            If grid is not one of the supported quadrature types.
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (only the nodes are needed for the inverse transform)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # compute splits of the spatial and spectral dimensions over the process grid
        self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
        self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
        self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
        self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)

        # compute legendre polynomial derivatives: (2, m, l, k)
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m: each azimuth rank only synthesizes its local orders
        dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth].contiguous()

        # register buffer
        self.register_buffer('dpct', dpct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        """
        Compute the distributed inverse vector SHT of the local spectral shard.

        Parameters
        ----------
        x: torch.Tensor
            Complex local shard of shape (..., channels, 2, lmax_local, mmax_local) holding the
            spheroidal (index 0) and toroidal (index 1) coefficients, with l split across the
            polar group and m split across the azimuth group.

        Returns
        -------
        torch.Tensor
            Real local shard of shape (..., channels, 2, nlat_local, nlon_local), with latitude
            split across the polar group and longitude split across the azimuth group.

        Raises
        ------
        ValueError
            If x has fewer than 4 dimensions.
        """

        if x.dim() < 4:
            raise ValueError(f"Expected tensor with at least 4 dimensions but got {x.dim()} instead")

        # store the (unsplit) number of input channels; needed for the channel split shapes below
        num_chans = x.shape[-4]

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            x = distributed_transpose_polar.apply(x, (-4, -2), self.l_shapes)

        # Evaluate associated Legendre functions on the output nodes (real view: last dim is real/imag)
        x = torch.view_as_real(x)

        # contraction - spheroidal component
        # real component
        srl =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1].to(x.dtype))
        # imag component
        sim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) \
              + torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1].to(x.dtype))

        # contraction - toroidal component
        # real component
        trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0].to(x.dtype))
        # imag component
        tim =   torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) \
              - torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0].to(x.dtype))

        # reassemble: (..., c, 2, k, m, 2)
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # convert to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, latitude is split and channels are local
        if self.comm_size_polar > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
            x = distributed_transpose_polar.apply(x, (-2, -4), chan_shapes)

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            x = distributed_transpose_azimuth.apply(x, (-4, -1), self.m_shapes)

        # set DCT and nyquist frequencies to zero
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < x.shape[-1]):
            x[..., self.nlon // 2].imag = 0.0

        # apply the inverse (real) FFT
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        # transpose: after this, longitude is split and channels are local
        if self.comm_size_azimuth > 1:
            chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
            x = distributed_transpose_azimuth.apply(x, (-1, -4), chan_shapes)

        return x