sht.py 16.6 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
5
#
Boris Bonev's avatar
Boris Bonev committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import numpy as np
import torch
import torch.nn as nn
import torch.fft

37
38
from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
Boris Bonev's avatar
Boris Bonev committed
39
40
41


class RealSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes the quadrature nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: number of retained degrees l (exclusive upper bound). A falsy value selects the
              default: nlat for "legendre-gauss" and "equiangular", nlat - 1 for "lobatto"
        mmax: number of retained orders m. A falsy value selects nlon // 2 + 1
        grid: grid in the latitude direction, one of "legendre-gauss", "lobatto" or "equiangular"
              (for now only tensor product grids are supported)
        norm: normalization forwarded to `_precompute_legpoly`
        csphase: forwarded to `_precompute_legpoly` (presumably toggles the Condon-Shortley phase)

        Raises:
        ValueError: if `grid` is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points and lmax based on the exactness of the quadrature
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Legendre is 2 * nlat - 1 >= 2 * lmax
            # and therefore lmax = nlat - 1 (inclusive)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Lobatto is 2 * nlat - 3 >= 2 * lmax
            # and therefore lmax = nlat - 2 (inclusive)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            # in principle, Clenshaw-Curtis quadrature is only exact up to polynomial degrees of about nlat;
            # however, we observe that the quadrature is remarkably accurate for higher degrees. This is why we do not
            # choose a lower lmax for now.
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes cos(theta) into angles theta = arccos(cost) and reverse their order
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # fold the quadrature weights into the Legendre polynomials: weights[m, l, k] = pct[m, l, k] * w[k]
        weights = torch.from_numpy(w)
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        pct = torch.from_numpy(pct)
        weights = torch.einsum("mlk,k->mlk", pct, weights)

        # non-persistent buffer: it follows .to(device) but is rebuilt in __init__ rather than saved in state_dict
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Apply the forward SHT over the last two dimensions of ``x``.

        Parameters:
        x: real tensor of shape (..., nlat, nlon)

        Returns:
        complex tensor of shape (..., lmax, mmax)

        Raises:
        ValueError: if ``x`` has fewer than 2 dimensions
        AssertionError: if the trailing dimensions do not match (nlat, nlon)
        """

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.nlat, f"expected x.shape[-2] == nlat == {self.nlat}, got {x.shape[-2]}"
        assert x.shape[-1] == self.nlon, f"expected x.shape[-1] == nlon == {self.nlon}, got {x.shape[-1]}"

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # do the quadrature on the (real, imag) pairs exposed as a trailing axis of size 2
        x = torch.view_as_real(x)

        # distributed contraction: fork
        out_shape = list(x.size())
        out_shape[-3] = self.lmax
        out_shape[-2] = self.mmax
        xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # contraction: the weights are real, so real and imaginary parts are contracted independently;
        # cast the weights once instead of once per component
        weights = self.weights.to(x.dtype)
        xout[..., 0] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], weights)
        xout[..., 1] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], weights)

        return torch.view_as_complex(xout)

139

Boris Bonev's avatar
Boris Bonev committed
140
class InverseRealSHT(nn.Module):
    r"""
    Module computing the inverse (real-valued) spherical harmonic transform.

    The associated Legendre polynomials are evaluated once, at construction time, on the
    nodes of the selected latitudinal quadrature grid and kept as a non-persistent buffer.

    nlat, nlon: output (grid) dimensions
    lmax, mmax: input (spectral) dimensions; when omitted they are derived from the output dimensions

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Set up the inverse SHT layer and precompute the Legendre polynomials.

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: number of input degrees; a falsy value selects nlat (nlat - 1 for "lobatto")
        mmax: number of input orders; a falsy value selects nlon // 2 + 1
        grid: latitudinal grid, one of "legendre-gauss", "lobatto" or "equiangular"
        norm, csphase: forwarded unchanged to `_precompute_legpoly`

        Raises:
        ValueError: if `grid` is not a supported quadrature mode
        """
        super().__init__()

        self.nlat, self.nlon = nlat, nlon
        self.grid, self.norm, self.csphase = grid, norm, csphase

        # select the quadrature rule; synthesis only needs its nodes, not its weights
        if self.grid == "legendre-gauss":
            quadrature, default_lmax = legendre_gauss_weights, self.nlat
        elif self.grid == "lobatto":
            quadrature, default_lmax = lobatto_weights, self.nlat - 1
        elif self.grid == "equiangular":
            quadrature, default_lmax = clenshaw_curtiss_weights, self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        nodes, _ = quadrature(nlat, -1, 1)
        self.lmax = lmax or default_lmax
        self.mmax = mmax or self.nlon // 2 + 1

        # angles arccos(nodes), in reversed order
        theta = np.flip(np.arccos(nodes))

        legpoly = _precompute_legpoly(self.mmax, self.lmax, theta, norm=self.norm, inverse=True, csphase=self.csphase)
        self.register_buffer("pct", torch.from_numpy(legpoly), persistent=False)

    def extra_repr(self):
        r"""
        Summary of the layer configuration shown by ``repr``.
        """
        return (
            f"nlat={self.nlat}, nlon={self.nlon},\n"
            f" lmax={self.lmax}, mmax={self.mmax},\n"
            f" grid={self.grid}, csphase={self.csphase}"
        )

    def forward(self, x: torch.Tensor):
        r"""
        Synthesize real grid values from spherical harmonic coefficients.

        x: complex tensor of shape (..., lmax, mmax)
        returns: real tensor whose last dimension has size nlon
        """
        ndim = len(x.shape)
        if ndim < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {ndim} instead")

        assert x.shape[-2] == self.lmax
        assert x.shape[-1] == self.mmax

        # expose (real, imag) as a trailing axis and sum over degrees l for each part separately
        pairs = torch.view_as_real(x)
        legendre = self.pct.to(pairs.dtype)
        parts = [torch.einsum("...lm, mlk->...km", pairs[..., c], legendre) for c in (0, 1)]

        # recombine into complex Fourier coefficients and go back to longitudes
        fourier = torch.view_as_complex(torch.stack(parts, -1))
        return torch.fft.irfft(fourier, n=self.nlon, dim=-1, norm="forward")


class RealVectorSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real) vector SHT.
    Precomputes the quadrature nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: number of retained degrees l (exclusive upper bound). A falsy value selects the
              default: nlat for "legendre-gauss" and "equiangular", nlat - 1 for "lobatto"
        mmax: number of retained orders m. A falsy value selects nlon // 2 + 1
        grid: type of grid the data lives on: "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization forwarded to `_precompute_dlegpoly`
        csphase: forwarded to `_precompute_dlegpoly` (presumably toggles the Condon-Shortley phase)

        Raises:
        ValueError: if `grid` is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # convert the nodes cos(theta) into angles theta = arccos(cost) and reverse their order
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        weights = torch.from_numpy(w)
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        dpct = torch.from_numpy(dpct)

        # combine integration weights and the per-degree factor 1 / (l (l + 1)) into one tensor.
        # The factor is built in dpct's dtype: deriving it from an integer arange would promote it to
        # the default float dtype (float32) and silently truncate it to single precision.
        norm_factor = self._vector_norm_factor(self.lmax, dtype=dpct.dtype)
        weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor)

        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # non-persistent buffer: it follows .to(device) but is rebuilt in __init__ rather than saved in state_dict
        self.register_buffer("weights", weights, persistent=False)

    @staticmethod
    def _vector_norm_factor(lmax, dtype=torch.float64):
        r"""
        Return [1 / (l (l + 1)) for l = 0, ..., lmax - 1] as a tensor of ``dtype``;
        the l = 0 entry, where the expression is undefined, is set to 1.
        """
        l = torch.arange(lmax, dtype=dtype)
        norm_factor = torch.ones_like(l)
        # skip l = 0 so no division by zero happens at all
        norm_factor[1:] = 1.0 / (l[1:] * (l[1:] + 1.0))
        return norm_factor

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Apply the forward vector SHT over the last three dimensions of ``x``.

        Parameters:
        x: real tensor of shape (..., 2, nlat, nlon) holding the two vector components

        Returns:
        complex tensor of shape (..., 2, lmax, mmax)

        Raises:
        ValueError: if ``x`` has fewer than 3 dimensions
        AssertionError: if the trailing dimensions do not match (nlat, nlon)
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.nlat, f"expected x.shape[-2] == nlat == {self.nlat}, got {x.shape[-2]}"
        assert x.shape[-1] == self.nlon, f"expected x.shape[-1] == nlon == {self.nlon}, got {x.shape[-1]}"

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # do the quadrature on the (real, imag) pairs exposed as a trailing axis of size 2
        x = torch.view_as_real(x)

        # distributed contraction: fork
        out_shape = list(x.size())
        out_shape[-3] = self.lmax
        out_shape[-2] = self.mmax
        xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # cast each weight component once instead of once per contraction
        w0 = self.weights[0].to(x.dtype)
        w1 = self.weights[1].to(x.dtype)

        # contraction - spheroidal component
        # real component
        xout[..., 0, :, :, 0] = torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 0], w0) - torch.einsum(
            "...km,mlk->...lm", x[..., 1, :, : self.mmax, 1], w1
        )

        # imag component
        xout[..., 0, :, :, 1] = torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 1], w0) + torch.einsum(
            "...km,mlk->...lm", x[..., 1, :, : self.mmax, 0], w1
        )

        # contraction - toroidal component
        # real component
        xout[..., 1, :, :, 0] = -torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 1], w1) - torch.einsum(
            "...km,mlk->...lm", x[..., 1, :, : self.mmax, 0], w0
        )
        # imag component
        xout[..., 1, :, :, 1] = torch.einsum("...km,mlk->...lm", x[..., 0, :, : self.mmax, 0], w1) - torch.einsum(
            "...km,mlk->...lm", x[..., 1, :, : self.mmax, 1], w0
        )

        return torch.view_as_complex(xout)


class InverseRealVectorSHT(nn.Module):
    r"""
    Module computing the inverse (real-valued) vector spherical harmonic transform.

    The Legendre-derivative tables returned by `_precompute_dlegpoly` are evaluated once, at
    construction time, on the nodes of the selected latitudinal quadrature grid.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Set up the inverse vector SHT layer and precompute its polynomial tables.

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: number of input degrees; a falsy value selects nlat (nlat - 1 for "lobatto")
        mmax: number of input orders; a falsy value selects nlon // 2 + 1
        grid: latitudinal grid, one of "legendre-gauss", "lobatto" or "equiangular"
        norm, csphase: forwarded unchanged to `_precompute_dlegpoly`

        Raises:
        ValueError: if `grid` is not a supported quadrature mode
        """
        super().__init__()

        self.nlat, self.nlon = nlat, nlon
        self.grid, self.norm, self.csphase = grid, norm, csphase

        # select the quadrature rule; synthesis only needs its nodes, not its weights
        if self.grid == "legendre-gauss":
            quadrature, default_lmax = legendre_gauss_weights, self.nlat
        elif self.grid == "lobatto":
            quadrature, default_lmax = lobatto_weights, self.nlat - 1
        elif self.grid == "equiangular":
            quadrature, default_lmax = clenshaw_curtiss_weights, self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        nodes, _ = quadrature(nlat, -1, 1)
        self.lmax = lmax or default_lmax
        self.mmax = mmax or self.nlon // 2 + 1

        # angles arccos(nodes), in reversed order
        theta = np.flip(np.arccos(nodes))

        dlegpoly = _precompute_dlegpoly(self.mmax, self.lmax, theta, norm=self.norm, inverse=True, csphase=self.csphase)
        self.register_buffer("dpct", torch.from_numpy(dlegpoly), persistent=False)

    def extra_repr(self):
        r"""
        Summary of the layer configuration shown by ``repr``.
        """
        return (
            f"nlat={self.nlat}, nlon={self.nlon},\n"
            f" lmax={self.lmax}, mmax={self.mmax},\n"
            f" grid={self.grid}, csphase={self.csphase}"
        )

    def forward(self, x: torch.Tensor):
        r"""
        Synthesize a real two-component field from vector spherical harmonic coefficients.

        x: complex tensor of shape (..., 2, lmax, mmax)
        returns: real tensor of shape (..., 2, K, nlon), where K is the node axis of ``dpct``
                 (presumably nlat — depends on `_precompute_dlegpoly`'s layout)
        """
        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        assert x.shape[-2] == self.lmax
        assert x.shape[-1] == self.mmax

        pairs = torch.view_as_real(x)
        d0 = self.dpct[0].to(pairs.dtype)
        d1 = self.dpct[1].to(pairs.dtype)

        def synth(coeffs, table):
            # sum over degrees l for every order m and node k
            return torch.einsum("...lm,mlk->...km", coeffs, table)

        # real / imaginary parts of the two input channels
        a_re, a_im = pairs[..., 0, :, :, 0], pairs[..., 0, :, :, 1]
        b_re, b_im = pairs[..., 1, :, :, 0], pairs[..., 1, :, :, 1]

        # first output channel (labelled "spheroidal" in the reference implementation)
        first = torch.stack(
            (
                synth(a_re, d0) - synth(b_im, d1),
                synth(a_im, d0) + synth(b_re, d1),
            ),
            -1,
        )
        # second output channel (labelled "toroidal")
        second = torch.stack(
            (
                -synth(a_im, d1) - synth(b_re, d0),
                synth(a_re, d1) - synth(b_im, d0),
            ),
            -1,
        )

        # place the channel axis in front of (node, order, re/im), then go back to longitudes
        fourier = torch.view_as_complex(torch.stack((first, second), -4))
        return torch.fft.irfft(fourier, n=self.nlon, dim=-1, norm="forward")