# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import torch
import torch.nn as nn
import torch.fft

from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly


class RealSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the SHT Layer, precomputing the necessary quadrature weights.

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: number of Legendre degrees in the output (degrees 0 .. lmax - 1). A falsy value selects the
              default: nlat for the "legendre-gauss" and "equiangular" grids, nlat - 1 for "lobatto".
        mmax: number of Fourier orders in the output (orders 0 .. mmax - 1). A falsy value selects the
              default nlon // 2 + 1.
        grid: quadrature grid in the latitude direction, one of "legendre-gauss", "lobatto" or
              "equiangular" (for now only tensor product grids are supported)
        norm: normalization mode, passed through to the Legendre polynomial precomputation
        csphase: flag passed through to the Legendre polynomial precomputation
                 (presumably toggles the Condon-Shortley phase; confirm in torch_harmonics.legendre)

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points and lmax based on the exactness of the quadrature
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Legendre is 2 * nlat - 1 >= 2 * lmax
            # and therefore lmax = nlat - 1 (inclusive)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Lobatto is 2 * nlat - 3 >= 2 * lmax
            # and therefore lmax = nlat - 2 (inclusive)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # in principle, Clenshaw-Curtis quadrature is only exact up to polynomial degrees of nlat
            # however, we observe that the quadrature is remarkably accurate for higher degrees. This is why we do not
            # choose a lower lmax for now.
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # map the quadrature nodes (cost) to angles via arccos and reverse their order
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # combine quadrature weights with the legendre weights
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum("mlk,k->mlk", pct, weights).contiguous()

        # remember quadrature weights
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Applies the forward SHT to the last two dimensions of x.

        Parameters:
        x: real tensor of shape (..., nlat, nlon)

        Returns:
        complex tensor of shape (..., lmax, mmax) with the spherical harmonic coefficients

        Raises:
        ValueError: if x has fewer than 2 dimensions
        AssertionError: if the trailing dimensions do not match (nlat, nlon)
        """

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.nlat, f"Expected nlat={self.nlat} in dimension -2 but got {x.shape[-2]}"
        assert x.shape[-1] == self.nlon, f"Expected nlon={self.nlon} in dimension -1 but got {x.shape[-1]}"

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # split into real and imaginary parts: (..., nlat, nlon // 2 + 1, 2)
        x = torch.view_as_real(x)

        # cast the precomputed weights once per call rather than once per contraction
        weights = self.weights.to(x.dtype)

        # Legendre quadrature: contract the latitude nodes k against the weights of each retained order m;
        # real and imaginary parts are contracted separately
        xre = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], weights)
        xim = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], weights)

        return torch.view_as_complex(torch.stack((xre, xim), dim=-1))

class InverseRealSHT(nn.Module):
    r"""
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    nlat, nlon: Output dimensions
    lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the inverse SHT Layer, precomputing the Legendre polynomials on the output nodes.

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: number of Legendre degrees in the input (degrees 0 .. lmax - 1). A falsy value selects the
              default: nlat for the "legendre-gauss" and "equiangular" grids, nlat - 1 for "lobatto".
        mmax: number of Fourier orders in the input (orders 0 .. mmax - 1). A falsy value selects the
              default nlon // 2 + 1.
        grid: grid in the latitude direction, one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode, passed through to the Legendre polynomial precomputation
        csphase: flag passed through to the Legendre polynomial precomputation
                 (presumably toggles the Condon-Shortley phase; confirm in torch_harmonics.legendre)

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (only the nodes are needed for the inverse transform)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # map the quadrature nodes (cost) to angles via arccos and reverse their order
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # register buffer
        self.register_buffer("pct", pct, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Applies the inverse SHT to the last two dimensions of x.

        Parameters:
        x: complex tensor of shape (..., lmax, mmax) with spherical harmonic coefficients

        Returns:
        real tensor of shape (..., nlat, nlon)

        Raises:
        ValueError: if x has fewer than 2 dimensions
        AssertionError: if the trailing dimensions do not match (lmax, mmax)
        """

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.lmax, f"Expected lmax={self.lmax} in dimension -2 but got {x.shape[-2]}"
        assert x.shape[-1] == self.mmax, f"Expected mmax={self.mmax} in dimension -1 but got {x.shape[-1]}"

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)
        xs = torch.einsum("...lmr, mlk->...kmr", x, self.pct.to(x.dtype)).contiguous()

        # apply the inverse (real) FFT
        x = torch.view_as_complex(xs)

        # ensure that imaginary part of 0 and nyquist components are zero
        # this is important because not all backend algorithms provided through the
        # irfft interface ensure that
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
            x[..., self.nlon // 2].imag = 0.0

        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        return x


class RealVectorSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights.

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: number of Legendre degrees in the output (degrees 0 .. lmax - 1). A falsy value selects the
              default: nlat for the "legendre-gauss" and "equiangular" grids, nlat - 1 for "lobatto".
        mmax: number of Fourier orders in the output (orders 0 .. mmax - 1). A falsy value selects the
              default nlon // 2 + 1.
        grid: type of grid the data lives on, one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode, passed through to the Legendre derivative precomputation
        csphase: flag passed through to the Legendre derivative precomputation
                 (presumably toggles the Condon-Shortley phase; confirm in torch_harmonics.legendre)

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # map the quadrature nodes (cost) to angles via arccos and reverse their order
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine integration weights and the 1 / (l (l + 1)) normalization factor into one.
        # The factor is computed in the precision of dpct (not the default float dtype), and the
        # denominator is clamped to 1 so that l = 0 gets a factor of 1.0 without dividing by zero.
        l = torch.arange(0, self.lmax, dtype=dpct.dtype)
        norm_factor = 1.0 / (l * (l + 1)).clamp(min=1.0)
        weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor).contiguous()

        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # remember quadrature weights
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Applies the forward vector SHT to the last three dimensions of x.

        Parameters:
        x: real tensor of shape (..., 2, nlat, nlon) holding the two vector components

        Returns:
        complex tensor of shape (..., 2, lmax, mmax); index 0 of the component dimension is the
        spheroidal part and index 1 the toroidal part

        Raises:
        ValueError: if x has fewer than 3 dimensions
        AssertionError: if the trailing dimensions do not match (nlat, nlon)
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.nlat, f"Expected nlat={self.nlat} in dimension -2 but got {x.shape[-2]}"
        assert x.shape[-1] == self.nlon, f"Expected nlon={self.nlon} in dimension -1 but got {x.shape[-1]}"

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # split into real and imaginary parts: (..., 2, nlat, nlon // 2 + 1, 2)
        x = torch.view_as_real(x)

        # cast both weight components once per call rather than once per contraction
        w0 = self.weights[0].to(x.dtype)
        w1 = self.weights[1].to(x.dtype)

        # real / imaginary parts of both input components, restricted to the retained orders
        x0re = x[..., 0, :, : self.mmax, 0]
        x0im = x[..., 0, :, : self.mmax, 1]
        x1re = x[..., 1, :, : self.mmax, 0]
        x1im = x[..., 1, :, : self.mmax, 1]

        def contract(a, w):
            # Legendre quadrature over the latitude nodes k for each order m
            return torch.einsum("...km,mlk->...lm", a, w)

        # contraction - spheroidal component (real, imag)
        srl = contract(x0re, w0) - contract(x1im, w1)
        sim = contract(x0im, w0) + contract(x1re, w1)

        # contraction - toroidal component (real, imag)
        trl = -contract(x0im, w1) - contract(x1re, w0)
        tim = contract(x0re, w1) - contract(x1im, w0)

        # reassemble into (..., 2, lmax, mmax, 2) and view as complex
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        return torch.view_as_complex(torch.stack((s, t), -4))


class InverseRealVectorSHT(nn.Module):
    r"""
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the inverse vector SHT Layer, precomputing the Legendre derivatives on the output nodes.

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: number of Legendre degrees in the input (degrees 0 .. lmax - 1). A falsy value selects the
              default: nlat for the "legendre-gauss" and "equiangular" grids, nlat - 1 for "lobatto".
        mmax: number of Fourier orders in the input (orders 0 .. mmax - 1). A falsy value selects the
              default nlon // 2 + 1.
        grid: grid in the latitude direction, one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode, passed through to the Legendre derivative precomputation
        csphase: flag passed through to the Legendre derivative precomputation
                 (presumably toggles the Condon-Shortley phase; confirm in torch_harmonics.legendre)

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (only the nodes are needed for the inverse transform)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # map the quadrature nodes (cost) to angles via arccos and reverse their order
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # register weights
        self.register_buffer("dpct", dpct, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Applies the inverse vector SHT to the last three dimensions of x.

        Parameters:
        x: complex tensor of shape (..., 2, lmax, mmax); index 0 of the component dimension is the
           spheroidal part and index 1 the toroidal part

        Returns:
        real tensor of shape (..., 2, nlat, nlon)

        Raises:
        ValueError: if x has fewer than 3 dimensions
        AssertionError: if the trailing dimensions do not match (lmax, mmax)
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.lmax, f"Expected lmax={self.lmax} in dimension -2 but got {x.shape[-2]}"
        assert x.shape[-1] == self.mmax, f"Expected mmax={self.mmax} in dimension -1 but got {x.shape[-1]}"

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # cast both components of the precomputed derivatives once per call
        d0 = self.dpct[0].to(x.dtype)
        d1 = self.dpct[1].to(x.dtype)

        # real / imaginary parts of the spheroidal (0) and toroidal (1) coefficients
        x0re = x[..., 0, :, :, 0]
        x0im = x[..., 0, :, :, 1]
        x1re = x[..., 1, :, :, 0]
        x1im = x[..., 1, :, :, 1]

        def contract(a, d):
            # sum over degrees l for each order m, evaluated at every latitude node k
            return torch.einsum("...lm,mlk->...km", a, d)

        # contraction - spheroidal component (real, imag)
        srl = contract(x0re, d0) - contract(x1im, d1)
        sim = contract(x0im, d0) + contract(x1re, d1)

        # contraction - toroidal component (real, imag)
        trl = -contract(x0im, d1) - contract(x1re, d0)
        tim = contract(x0re, d1) - contract(x1im, d0)

        # reassemble into (..., 2, nlat, mmax, 2)
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # apply the inverse (real) FFT
        x = torch.view_as_complex(xs)

        # ensure that imaginary part of 0 and nyquist components are zero
        # this is important because not all backend algorithms provided through the
        # irfft interface ensure that
        x[..., 0].imag = 0.0
        if (self.nlon % 2 == 0) and (self.nlon // 2 < self.mmax):
            x[..., self.nlon // 2].imag = 0.0

        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        return x