# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import torch
import torch.nn as nn
import torch.fft

from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly


class RealSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes the quadrature nodes and weights of the chosen grid together with the associated
    Legendre polynomials evaluated on these nodes. The SHT is applied to the last two dimensions
    of the input.

    Parameters
    -----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int, optional
        Number of spherical harmonic degrees retained (size of the degree dimension of the output).
        If not provided, defaults to ``nlat`` for the "equiangular" and "legendre-gauss" grids and
        to ``nlat - 1`` for the "lobatto" grid.
    mmax: int, optional
        Number of spherical harmonic orders retained (size of the order dimension of the output).
        If not provided, defaults to ``nlon // 2 + 1``.
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Raises
    ------
    ValueError
        If the grid type is unknown

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the SHT layer. Computes the quadrature nodes/weights of ``grid`` and stores,
        as the non-persistent buffer ``weights`` of shape (mmax, lmax, nlat), the quadrature
        weights multiplied by the associated Legendre polynomials on the nodes.

        See the class docstring for the meaning of the parameters.

        Raises
        ------
        ValueError
            If the grid type is unknown
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points and lmax based on the exactness of the quadrature
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Legendre is 2 * nlat - 1 >= 2 * lmax
            # and therefore lmax = nlat - 1 (inclusive)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Lobatto is 2 * nlat - 3 >= 2 * lmax
            # and therefore lmax = nlat - 2 (inclusive)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # in principle, Clenshaw-Curtis quadrature is only exact up to polynomial degrees of nlat
            # however, we observe that the quadrature is remarkably accurate for higher degrees. This is why we do not
            # choose a lower lmax for now.
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes cos(theta) to angles theta and reverse their order
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # combine quadrature weights with the legendre weights
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum("mlk,k->mlk", pct, weights).contiguous()

        # remember quadrature weights
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Computes the forward SHT of ``x``.

        Parameters
        ----------
        x: torch.Tensor
            Real-valued tensor of shape (..., nlat, nlon)

        Returns
        -------
        torch.Tensor
            Complex-valued tensor of spherical harmonic coefficients of shape (..., lmax, mmax)

        Raises
        ------
        ValueError
            If ``x`` has fewer than 2 dimensions
        """

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.nlat, f"Expected {self.nlat} latitude points but got {x.shape[-2]}"
        assert x.shape[-1] == self.nlon, f"Expected {self.nlon} longitude points but got {x.shape[-1]}"

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # do the quadrature on the real and imaginary parts separately;
        # after view_as_real the layout is (..., nlat, nfreq, 2)
        x = torch.view_as_real(x)
        weights = self.weights.to(x.dtype)

        out_shape = list(x.size())
        out_shape[-3] = self.lmax
        out_shape[-2] = self.mmax
        xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # contraction over the latitude nodes k, keeping only the first mmax orders
        xout[..., 0] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], weights)
        xout[..., 1] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], weights)

        return torch.view_as_complex(xout)

class InverseRealSHT(nn.Module):
    r"""
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes the quadrature nodes of the chosen grid and the associated Legendre polynomials
    evaluated on these nodes. The inverse SHT is applied to the last two dimensions of the input.

    Parameters
    -----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int, optional
        Number of spherical harmonic degrees (size of the degree dimension of the input).
        If not provided, defaults to ``nlat`` for the "equiangular" and "legendre-gauss" grids and
        to ``nlat - 1`` for the "lobatto" grid.
    mmax: int, optional
        Number of spherical harmonic orders (size of the order dimension of the input).
        If not provided, defaults to ``nlon // 2 + 1``.
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Raises
    ------
    ValueError
        If the grid type is unknown

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the inverse SHT layer. Computes the quadrature nodes of ``grid`` and stores the
        associated Legendre polynomials evaluated on them as the non-persistent buffer ``pct``.

        See the class docstring for the meaning of the parameters.

        Raises
        ------
        ValueError
            If the grid type is unknown
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (the quadrature weights are not needed for synthesis)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes cos(theta) to angles theta and reverse their order
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # register buffer
        self.register_buffer("pct", pct, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Computes the inverse SHT of the coefficients ``x``.

        Parameters
        ----------
        x: torch.Tensor
            Complex-valued tensor of spherical harmonic coefficients of shape (..., lmax, mmax)

        Returns
        -------
        torch.Tensor
            Real-valued tensor of shape (..., nlat, nlon)

        Raises
        ------
        ValueError
            If ``x`` has fewer than 2 dimensions
        """

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.lmax, f"Expected {self.lmax} degrees but got {x.shape[-2]}"
        assert x.shape[-1] == self.mmax, f"Expected {self.mmax} orders but got {x.shape[-1]}"

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)
        xs = torch.einsum("...lmr, mlk->...kmr", x, self.pct.to(x.dtype)).contiguous()

        # apply the inverse (real) FFT
        x = torch.view_as_complex(xs)

        # ensure that imaginary part of 0 and nyquist components are zero
        # this is important because not all backend algorithms provided through the
        # irfft interface ensure that
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
            x[..., self.nlon // 2].imag = 0.0

        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        return x


class RealVectorSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real) vector SHT.
    Precomputes the quadrature nodes and weights of the chosen grid together with the derivatives
    of the associated Legendre polynomials evaluated on these nodes. The SHT is applied to the last
    three dimensions of the input, the first of which holds the two vector components.

    Parameters
    -----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int, optional
        Number of spherical harmonic degrees retained (size of the degree dimension of the output).
        If not provided, defaults to ``nlat`` for the "equiangular" and "legendre-gauss" grids and
        to ``nlat - 1`` for the "lobatto" grid.
    mmax: int, optional
        Number of spherical harmonic orders retained (size of the order dimension of the output).
        If not provided, defaults to ``nlon // 2 + 1``.
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Raises
    ------
    ValueError
        If the grid type is unknown

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the vector SHT layer. Computes the quadrature nodes/weights of ``grid`` and
        stores, as the non-persistent buffer ``weights``, the quadrature weights combined with the
        Legendre-derivative terms and the 1 / (l (l + 1)) normalization factor.

        See the class docstring for the meaning of the parameters.

        Raises
        ------
        ValueError
            If the grid type is unknown
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes cos(theta) to angles theta and reverse their order
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine integration weights, normalization factor in to one:
        # 1 / (l (l + 1)) for l > 0; the l = 0 entry (a division by zero) is replaced by 1
        degrees = torch.arange(0, self.lmax)
        norm_factor = 1.0 / degrees / (degrees + 1)
        norm_factor[0] = 1.0
        weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor).contiguous()

        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # remember quadrature weights
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Computes the forward vector SHT of ``x``.

        Parameters
        ----------
        x: torch.Tensor
            Real-valued tensor of shape (..., 2, nlat, nlon) holding the two vector components

        Returns
        -------
        torch.Tensor
            Complex-valued tensor of shape (..., 2, lmax, mmax) whose index 0 along dimension -3
            is the spheroidal and index 1 the toroidal component

        Raises
        ------
        ValueError
            If ``x`` has fewer than 3 dimensions or dimension -3 does not have size 2
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # only components 0 and 1 are contracted below; any other size would either fail with an
        # IndexError or silently produce all-zero extra components
        if x.shape[-3] != 2:
            raise ValueError(f"Expected 2 vector components in dimension -3 but got {x.shape[-3]} instead")

        assert x.shape[-2] == self.nlat, f"Expected {self.nlat} latitude points but got {x.shape[-2]}"
        assert x.shape[-1] == self.nlon, f"Expected {self.nlon} longitude points but got {x.shape[-1]}"

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # do the quadrature on real and imaginary parts; layout is now (..., 2, nlat, nfreq, 2)
        x = torch.view_as_real(x)

        out_shape = list(x.size())
        out_shape[-3] = self.lmax
        out_shape[-2] = self.mmax
        xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # real/imaginary parts of both input components, truncated to the first mmax orders
        x0re = x[..., 0, :, : self.mmax, 0]
        x0im = x[..., 0, :, : self.mmax, 1]
        x1re = x[..., 1, :, : self.mmax, 0]
        x1im = x[..., 1, :, : self.mmax, 1]

        # cast the precomputed weights once instead of once per contraction
        w0 = self.weights[0].to(x.dtype)
        w1 = self.weights[1].to(x.dtype)

        def _contract(a, w):
            # contract over the latitude nodes k
            return torch.einsum("...km,mlk->...lm", a, w)

        # contraction - spheroidal component (real, imag)
        xout[..., 0, :, :, 0] = _contract(x0re, w0) - _contract(x1im, w1)
        xout[..., 0, :, :, 1] = _contract(x0im, w0) + _contract(x1re, w1)

        # contraction - toroidal component (real, imag)
        xout[..., 1, :, :, 0] = -_contract(x0im, w1) - _contract(x1re, w0)
        xout[..., 1, :, :, 1] = _contract(x0re, w1) - _contract(x1im, w0)

        return torch.view_as_complex(xout)


class InverseRealVectorSHT(nn.Module):
    r"""
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes the quadrature nodes of the chosen grid and the derivatives of the associated
    Legendre polynomials evaluated on these nodes. The inverse SHT is applied to the last three
    dimensions of the input, the first of which holds the spheroidal and toroidal components.

    Parameters
    -----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int, optional
        Number of spherical harmonic degrees (size of the degree dimension of the input).
        If not provided, defaults to ``nlat`` for the "equiangular" and "legendre-gauss" grids and
        to ``nlat - 1`` for the "lobatto" grid.
    mmax: int, optional
        Number of spherical harmonic orders (size of the order dimension of the input).
        If not provided, defaults to ``nlon // 2 + 1``.
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Raises
    ------
    ValueError
        If the grid type is unknown

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the inverse vector SHT layer. Computes the quadrature nodes of ``grid`` and
        stores the Legendre-derivative terms evaluated on them as the non-persistent buffer ``dpct``.

        See the class docstring for the meaning of the parameters.

        Raises
        ------
        ValueError
            If the grid type is unknown
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (the quadrature weights are not needed for synthesis)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes cos(theta) to angles theta and reverse their order
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # register weights
        self.register_buffer("dpct", dpct, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Computes the inverse vector SHT of the coefficients ``x``.

        Parameters
        ----------
        x: torch.Tensor
            Complex-valued tensor of shape (..., 2, lmax, mmax) whose index 0 along dimension -3
            is the spheroidal and index 1 the toroidal component

        Returns
        -------
        torch.Tensor
            Real-valued tensor of shape (..., 2, nlat, nlon)

        Raises
        ------
        ValueError
            If ``x`` has fewer than 3 dimensions or dimension -3 does not have size 2
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # only components 0 and 1 are used below; any other size would either fail with an
        # IndexError or have its extra components silently ignored
        if x.shape[-3] != 2:
            raise ValueError(f"Expected 2 vector components in dimension -3 but got {x.shape[-3]} instead")

        assert x.shape[-2] == self.lmax, f"Expected {self.lmax} degrees but got {x.shape[-2]}"
        assert x.shape[-1] == self.mmax, f"Expected {self.mmax} orders but got {x.shape[-1]}"

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # real/imaginary parts of the spheroidal (0) and toroidal (1) coefficients
        x0re = x[..., 0, :, :, 0]
        x0im = x[..., 0, :, :, 1]
        x1re = x[..., 1, :, :, 0]
        x1im = x[..., 1, :, :, 1]

        # cast the precomputed Legendre terms once instead of once per contraction
        d0 = self.dpct[0].to(x.dtype)
        d1 = self.dpct[1].to(x.dtype)

        def _contract(a, d):
            # sum over the degrees l, producing values on the latitude nodes k
            return torch.einsum("...lm,mlk->...km", a, d)

        # contraction - spheroidal component (real, imag)
        srl = _contract(x0re, d0) - _contract(x1im, d1)
        sim = _contract(x0im, d0) + _contract(x1re, d1)

        # contraction - toroidal component (real, imag)
        trl = -_contract(x0im, d1) - _contract(x1re, d0)
        tim = _contract(x0re, d1) - _contract(x1im, d0)

        # reassemble into (..., 2, nlat, mmax, 2)
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # apply the inverse (real) FFT
        x = torch.view_as_complex(xs)

        # ensure that imaginary part of 0 and nyquist components are zero
        # this is important because not all backend algorithms provided through the
        # irfft interface ensure that
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
            x[..., self.nlon // 2].imag = 0.0

        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        return x