# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import torch
import torch.nn as nn
import torch.fft

from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly


class RealSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes the quadrature nodes, the quadrature weights and the associated Legendre
    polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input.

    Parameters
    -----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int, optional
        Size of the spherical harmonic degree axis of the output. If not given, nlat is used
        for the "legendre-gauss" and "equiangular" grids and nlat - 1 for the "lobatto" grid
    mmax: int, optional
        Size of the spherical harmonic order axis of the output, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Raises
    ------
    ValueError: If the grid type is unknown

    Returns
    -------
    x: torch.Tensor
        Complex tensor of shape (..., lmax, mmax)

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points and lmax based on the exactness of the quadrature
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Legendre is 2 * nlat - 1 >= 2 * l,
            # so the degrees 0, ..., nlat - 1 are resolved (nlat degrees in total)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Lobatto is 2 * nlat - 3 >= 2 * l,
            # so the degrees 0, ..., nlat - 2 are resolved (nlat - 1 degrees in total)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # in principle, Clenshaw-Curtiss quadrature is only exact up to polynomial degrees of nlat
            # however, we observe that the quadrature is remarkably accurate for higher degrees. This is why we do not
            # choose a lower lmax for now.
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes from cos(theta) to angles and reverse their order
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # fold the quadrature weights into the Legendre polynomials so that the
        # forward pass only needs a single contraction over the latitude axis
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum("mlk,k->mlk", pct, weights).contiguous()

        # remember quadrature weights (recomputed on construction, hence not persistent)
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self) -> str:
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the forward SHT to a real field whose last two dimensions are (nlat, nlon).

        Raises
        ------
        ValueError: If x has fewer than 2 dimensions or its last two dimensions
            differ from (nlat, nlon)
        """

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        # explicit checks instead of assert, which is stripped when running with -O
        if x.shape[-2] != self.nlat or x.shape[-1] != self.nlon:
            raise ValueError(f"Expected the last two dimensions to be ({self.nlat}, {self.nlon}) but got {tuple(x.shape[-2:])} instead")

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # work on the real view so that the contraction uses the real-valued weights
        x = torch.view_as_real(x)

        # convert the weights once and reuse them for the real and imaginary parts
        weights = self.weights.to(x.dtype)

        # quadrature in latitude; only the first mmax Fourier modes are retained
        real = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], weights)
        imag = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], weights)

        # stack yields a contiguous (..., lmax, mmax, 2) tensor as required by view_as_complex
        return torch.view_as_complex(torch.stack((real, imag), dim=-1))

150

Boris Bonev's avatar
Boris Bonev committed
151
class InverseRealSHT(nn.Module):
    r"""
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes the quadrature nodes and the associated Legendre polynomials on these nodes.
    The inverse SHT is applied to the last two dimensions of the input.

    Parameters
    -----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int, optional
        Size of the spherical harmonic degree axis of the input. If not given, nlat is used
        for the "legendre-gauss" and "equiangular" grids and nlat - 1 for the "lobatto" grid
    mmax: int, optional
        Size of the spherical harmonic order axis of the input, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Raises
    ------
    ValueError: If the grid type is unknown

    Returns
    -------
    x: torch.Tensor
        Real tensor of shape (..., nlat, nlon)

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points; the quadrature weights are not needed for the inverse
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes from cos(theta) to angles and reverse their order
        theta = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        pct = _precompute_legpoly(self.mmax, self.lmax, theta, norm=self.norm, inverse=True, csphase=self.csphase)

        # register buffer (recomputed on construction, hence not persistent)
        self.register_buffer("pct", pct, persistent=False)

    def extra_repr(self) -> str:
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse SHT to complex coefficients whose last two dimensions are (lmax, mmax).

        Raises
        ------
        ValueError: If x has fewer than 2 dimensions or its last two dimensions
            differ from (lmax, mmax)
        """

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        # explicit checks instead of assert, which is stripped when running with -O
        if x.shape[-2] != self.lmax or x.shape[-1] != self.mmax:
            raise ValueError(f"Expected the last two dimensions to be ({self.lmax}, {self.mmax}) but got {tuple(x.shape[-2:])} instead")

        # Evaluate associated Legendre functions on the output nodes;
        # contiguous() is needed so that the result can be viewed as complex again
        x = torch.view_as_real(x)
        xs = torch.einsum("...lmr,mlk->...kmr", x, self.pct.to(x.dtype)).contiguous()

        # xs is a freshly allocated tensor, so the in-place edits below never touch the caller's input
        x = torch.view_as_complex(xs)

        # ensure that imaginary part of 0 and nyquist components are zero
        # this is important because not all backend algorithms provided through the
        # irfft interface ensure that
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
            x[..., self.nlon // 2].imag = 0.0

        # apply the inverse (real) FFT
        return torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")


class RealVectorSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real) vector SHT.
    Precomputes the quadrature nodes, the quadrature weights and the derivatives of the
    associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input, which are expected to be
    (2, nlat, nlon).

    Parameters
    -----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int, optional
        Size of the spherical harmonic degree axis of the output. If not given, nlat is used
        for the "legendre-gauss" and "equiangular" grids and nlat - 1 for the "lobatto" grid
    mmax: int, optional
        Size of the spherical harmonic order axis of the output, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Raises
    ------
    ValueError: If the grid type is unknown

    Returns
    -------
    x: torch.Tensor
        Complex tensor of shape (..., 2, lmax, mmax) holding the spheroidal (index 0)
        and toroidal (index 1) coefficients

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes from cos(theta) to angles and reverse their order
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine integration weights, normalization factor in to one:
        degrees = torch.arange(0, self.lmax)
        norm_factor = 1.0 / degrees / (degrees + 1)
        # the expression above divides by zero for degree 0; overwrite that entry with 1
        norm_factor[0] = 1.0
        weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor).contiguous()
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # remember quadrature weights (recomputed on construction, hence not persistent)
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self) -> str:
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the forward vector SHT to a real field whose last three dimensions are (2, nlat, nlon).

        Raises
        ------
        ValueError: If x has fewer than 3 dimensions or its last three dimensions
            differ from (2, nlat, nlon)
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # explicit checks instead of assert, which is stripped when running with -O;
        # the component axis is checked as well because exactly two components are contracted below
        if x.shape[-3] != 2 or x.shape[-2] != self.nlat or x.shape[-1] != self.nlon:
            raise ValueError(f"Expected the last three dimensions to be (2, {self.nlat}, {self.nlon}) but got {tuple(x.shape[-3:])} instead")

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # work on the real view so that the contraction uses the real-valued weights
        x = torch.view_as_real(x)

        # convert the two weight sets once instead of once per contraction
        w0 = self.weights[0].to(x.dtype)
        w1 = self.weights[1].to(x.dtype)

        # real/imaginary parts of the two input components, truncated to the first mmax modes
        u_re = x[..., 0, :, : self.mmax, 0]
        u_im = x[..., 0, :, : self.mmax, 1]
        v_re = x[..., 1, :, : self.mmax, 0]
        v_im = x[..., 1, :, : self.mmax, 1]

        contraction = "...km,mlk->...lm"

        # contraction - spheroidal component (real, imag)
        srl = torch.einsum(contraction, u_re, w0) - torch.einsum(contraction, v_im, w1)
        sim = torch.einsum(contraction, u_im, w0) + torch.einsum(contraction, v_re, w1)

        # contraction - toroidal component (real, imag)
        trl = -torch.einsum(contraction, u_im, w1) - torch.einsum(contraction, v_re, w0)
        tim = torch.einsum(contraction, u_re, w1) - torch.einsum(contraction, v_im, w0)

        # reassemble into a contiguous (..., 2, lmax, mmax, 2) tensor
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xout = torch.stack((s, t), -4)

        return torch.view_as_complex(xout)


class InverseRealVectorSHT(nn.Module):
    r"""
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes the quadrature nodes and the derivatives of the associated Legendre
    polynomials on these nodes.
    The inverse SHT is applied to the last three dimensions of the input, which are expected
    to be (2, lmax, mmax).

    Parameters
    -----------
    nlat: int
        Number of latitude points
    nlon: int
        Number of longitude points
    lmax: int, optional
        Size of the spherical harmonic degree axis of the input. If not given, nlat is used
        for the "legendre-gauss" and "equiangular" grids and nlat - 1 for the "lobatto" grid
    mmax: int, optional
        Size of the spherical harmonic order axis of the input, by default nlon // 2 + 1
    grid: str
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    norm: str
        Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
    csphase: bool
        Whether to apply the Condon-Shortley phase factor, by default True

    Raises
    ------
    ValueError: If the grid type is unknown

    Returns
    -------
    x: torch.Tensor
        Real tensor of shape (..., 2, nlat, nlon)

    References
    ----------
    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points; the quadrature weights are not needed for the inverse
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes from cos(theta) to angles and reverse their order
        theta = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        dpct = _precompute_dlegpoly(self.mmax, self.lmax, theta, norm=self.norm, inverse=True, csphase=self.csphase)

        # register weights (recomputed on construction, hence not persistent)
        self.register_buffer("dpct", dpct, persistent=False)

    def extra_repr(self) -> str:
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse vector SHT to complex coefficients whose last three dimensions are (2, lmax, mmax).

        Raises
        ------
        ValueError: If x has fewer than 3 dimensions or its last three dimensions
            differ from (2, lmax, mmax)
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # explicit checks instead of assert, which is stripped when running with -O;
        # the component axis is checked as well because exactly two components are contracted below
        if x.shape[-3] != 2 or x.shape[-2] != self.lmax or x.shape[-1] != self.mmax:
            raise ValueError(f"Expected the last three dimensions to be (2, {self.lmax}, {self.mmax}) but got {tuple(x.shape[-3:])} instead")

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # convert the two polynomial sets once instead of once per contraction
        d0 = self.dpct[0].to(x.dtype)
        d1 = self.dpct[1].to(x.dtype)

        # real/imaginary parts of the spheroidal (a) and toroidal (b) coefficients
        a_re = x[..., 0, :, :, 0]
        a_im = x[..., 0, :, :, 1]
        b_re = x[..., 1, :, :, 0]
        b_im = x[..., 1, :, :, 1]

        contraction = "...lm,mlk->...km"

        # contraction - spheroidal component (real, imag)
        srl = torch.einsum(contraction, a_re, d0) - torch.einsum(contraction, b_im, d1)
        sim = torch.einsum(contraction, a_im, d0) + torch.einsum(contraction, b_re, d1)

        # contraction - toroidal component (real, imag)
        trl = -torch.einsum(contraction, a_im, d1) - torch.einsum(contraction, b_re, d0)
        tim = torch.einsum(contraction, a_re, d1) - torch.einsum(contraction, b_im, d0)

        # reassemble into a contiguous (..., 2, nlat, mmax, 2) tensor
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # xs is freshly allocated, so the in-place edits below never touch the caller's input
        x = torch.view_as_complex(xs)

        # ensure that imaginary part of 0 and nyquist components are zero
        # this is important because not all backend algorithms provided through the
        # irfft interface ensure that
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
            x[..., self.nlon // 2].imag = 0.0

        # apply the inverse (real) FFT
        return torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")