sht.py 16.5 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
3
4
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
5
#
Boris Bonev's avatar
Boris Bonev committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import torch
import torch.nn as nn
import torch.fft

36
37
from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights, clenshaw_curtiss_weights
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
Boris Bonev's avatar
Boris Bonev committed
38
39
40


class RealSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: number of Legendre degrees in the output (exclusive bound); if None (or 0), a grid-dependent
              default derived from the exactness of the quadrature is used
        mmax: number of Fourier orders in the output; defaults to nlon // 2 + 1. forward() slices the rfft
              output to mmax entries, so values larger than nlon // 2 + 1 are not supported there
        grid: grid in the latitude direction (for now only tensor product grids are supported);
              one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode forwarded to _precompute_legpoly
        csphase: flag forwarded to _precompute_legpoly (presumably toggles the Condon-Shortley phase)

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points and lmax based on the exactness of the quadrature
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Legendre is 2 * nlat - 1 >= 2 * lmax
            # and therefore lmax = nlat - 1 (inclusive)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            # maximum polynomial degree for Gauss-Lobatto is 2 * nlat - 3 >= 2 * lmax
            # and therefore lmax = nlat - 2 (inclusive)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            # in principle, Clenshaw-Curtis quadrature is only exact up to polynomial degrees of nlat
            # however, we observe that the quadrature is remarkably accurate for higher degrees. This is why we do not
            # choose a lower lmax for now.
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # map the quadrature nodes to angles via arccos and reverse their order
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # combine quadrature weights with the legendre weights
        pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum("mlk,k->mlk", pct, weights).contiguous()

        # remember quadrature weights (non-persistent: recomputed on construction, not stored in state_dict)
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Applies the forward SHT to the last two dimensions of x.

        Parameters:
        x: real tensor of shape (..., nlat, nlon)

        Returns:
        complex tensor of shape (..., lmax, mmax)

        Raises:
        ValueError: if x has fewer than 2 dimensions or its trailing dimensions are not (nlat, nlon)
        """

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        # explicit checks rather than `assert` (stripped under `python -O`); a wrong nlon would
        # otherwise not necessarily fail, since rfft and the mmax slice below can still succeed
        if x.shape[-2] != self.nlat or x.shape[-1] != self.nlon:
            raise ValueError(f"Expected trailing dimensions ({self.nlat}, {self.nlon}) but got {tuple(x.shape[-2:])} instead")

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # do the Legendre-Gauss quadrature on real and imaginary parts separately
        x = torch.view_as_real(x)

        # distributed contraction: fork
        out_shape = list(x.size())
        out_shape[-3] = self.lmax
        out_shape[-2] = self.mmax
        xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # contraction over the latitude nodes k (cast the buffer once)
        weights = self.weights.to(x.dtype)
        xout[..., 0] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 0], weights)
        xout[..., 1] = torch.einsum("...km,mlk->...lm", x[..., : self.mmax, 1], weights)
        x = torch.view_as_complex(xout)

        return x

136

Boris Bonev's avatar
Boris Bonev committed
137
class InverseRealSHT(nn.Module):
    r"""
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    nlat, nlon: Output dimensions
    lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the inverse SHT layer, precomputing the associated Legendre polynomials on the output grid

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: number of Legendre degrees of the input coefficients; if None (or 0), defaults to nlat
              ("legendre-gauss", "equiangular") or nlat - 1 ("lobatto")
        mmax: number of Fourier orders of the input coefficients; defaults to nlon // 2 + 1
        grid: grid in the latitude direction; one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode forwarded to _precompute_legpoly
        csphase: flag forwarded to _precompute_legpoly (presumably toggles the Condon-Shortley phase)

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (the weights are not needed for the inverse transform)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # map the quadrature nodes to angles via arccos and reverse their order
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # register buffer (non-persistent: recomputed on construction, not stored in state_dict)
        self.register_buffer("pct", pct, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Applies the inverse SHT to the last two dimensions of x.

        Parameters:
        x: complex tensor of shape (..., lmax, mmax)

        Returns:
        real tensor of shape (..., nlat, nlon)

        Raises:
        ValueError: if x has fewer than 2 dimensions or its trailing dimensions are not (lmax, mmax)
        """

        if x.dim() < 2:
            raise ValueError(f"Expected tensor with at least 2 dimensions but got {x.dim()} instead")

        # explicit check rather than `assert`, which is stripped under `python -O`
        if x.shape[-2] != self.lmax or x.shape[-1] != self.mmax:
            raise ValueError(f"Expected trailing dimensions ({self.lmax}, {self.mmax}) but got {tuple(x.shape[-2:])} instead")

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # contract over the degrees l, separately for real and imaginary parts (cast the buffer once)
        pct = self.pct.to(x.dtype)
        rl = torch.einsum("...lm, mlk->...km", x[..., 0], pct)
        im = torch.einsum("...lm, mlk->...km", x[..., 1], pct)
        xs = torch.stack((rl, im), -1)

        # apply the inverse (real) FFT
        x = torch.view_as_complex(xs)
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        return x


class RealVectorSHT(nn.Module):
    r"""
    Defines a module for computing the forward (real) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        lmax: number of Legendre degrees in the output; if None (or 0), defaults to nlat
              ("legendre-gauss", "equiangular") or nlat - 1 ("lobatto")
        mmax: number of Fourier orders in the output; defaults to nlon // 2 + 1
        grid: type of grid the data lives on; one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode forwarded to _precompute_dlegpoly
        csphase: flag forwarded to _precompute_dlegpoly (presumably toggles the Condon-Shortley phase)

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, weights = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, weights = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, weights = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # map the quadrature nodes to angles via arccos and reverse their order
        tq = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1
        dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine integration weights and the 1 / (l (l + 1)) normalization factor into one.
        # The factor is built in float64: dividing an integer arange by a Python float would
        # produce the default dtype (usually float32) and round the factor before it meets
        # the quadrature weights. l = 0 gets a factor of 1 (1 / 0 is inf and is overwritten).
        degrees = torch.arange(0, self.lmax, dtype=torch.float64)
        norm_factor = 1.0 / degrees / (degrees + 1)
        norm_factor[0] = 1.0
        weights = torch.einsum("dmlk,k,l->dmlk", dpct, weights, norm_factor.to(dpct.dtype)).contiguous()

        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # remember quadrature weights (non-persistent: recomputed on construction, not stored in state_dict)
        self.register_buffer("weights", weights, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Applies the forward vector SHT.

        Parameters:
        x: real tensor of shape (..., 2, nlat, nlon); dim -3 indexes the two vector components

        Returns:
        complex tensor of shape (..., 2, lmax, mmax) holding the spheroidal (index 0) and
        toroidal (index 1) coefficients

        Raises:
        ValueError: if x has fewer than 3 dimensions or its trailing dimensions are not (nlat, nlon)
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # explicit checks rather than `assert` (stripped under `python -O`); a wrong nlon would
        # otherwise not necessarily fail, since rfft and the mmax slice below can still succeed
        if x.shape[-2] != self.nlat or x.shape[-1] != self.nlon:
            raise ValueError(f"Expected trailing dimensions ({self.nlat}, {self.nlon}) but got {tuple(x.shape[-2:])} instead")

        # apply real fft in the longitudinal direction
        x = 2.0 * torch.pi * torch.fft.rfft(x, dim=-1, norm="forward")

        # do the Legendre-Gauss quadrature
        x = torch.view_as_real(x)

        # distributed contraction: fork
        out_shape = list(x.size())
        out_shape[-3] = self.lmax
        out_shape[-2] = self.mmax
        xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

        # cast each weight component once instead of once per contraction
        w0 = self.weights[0].to(x.dtype)
        w1 = self.weights[1].to(x.dtype)

        # real / imaginary parts of the two input components, truncated to mmax orders
        x0r = x[..., 0, :, : self.mmax, 0]
        x0i = x[..., 0, :, : self.mmax, 1]
        x1r = x[..., 1, :, : self.mmax, 0]
        x1i = x[..., 1, :, : self.mmax, 1]

        def contract(field, w):
            # quadrature over the latitude nodes k
            return torch.einsum("...km,mlk->...lm", field, w)

        # contraction - spheroidal component (real, imag)
        xout[..., 0, :, :, 0] = contract(x0r, w0) - contract(x1i, w1)
        xout[..., 0, :, :, 1] = contract(x0i, w0) + contract(x1r, w1)

        # contraction - toroidal component (real, imag)
        xout[..., 1, :, :, 0] = -contract(x0i, w1) - contract(x1r, w0)
        xout[..., 1, :, :, 1] = contract(x0r, w1) - contract(x1i, w0)

        return torch.view_as_complex(xout)


class InverseRealVectorSHT(nn.Module):
    r"""
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
        r"""
        Initializes the inverse vector SHT layer, precomputing the Legendre derivative terms on the output grid

        Parameters:
        nlat: output grid resolution in the latitudinal direction
        nlon: output grid resolution in the longitudinal direction
        lmax: number of Legendre degrees of the input coefficients; if None (or 0), defaults to nlat
              ("legendre-gauss", "equiangular") or nlat - 1 ("lobatto")
        mmax: number of Fourier orders of the input coefficients; defaults to nlon // 2 + 1
        grid: grid in the latitude direction; one of "legendre-gauss", "lobatto" or "equiangular"
        norm: normalization mode forwarded to _precompute_dlegpoly
        csphase: flag forwarded to _precompute_dlegpoly (presumably toggles the Condon-Shortley phase)

        Raises:
        ValueError: if grid is not one of the supported quadrature modes
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points (the weights are not needed for the inverse transform)
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat - 1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # map the quadrature nodes to angles via arccos and reverse their order
        t = torch.flip(torch.arccos(cost), dims=(0,))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # register weights (non-persistent: recomputed on construction, not stored in state_dict)
        self.register_buffer("dpct", dpct, persistent=False)

    def extra_repr(self):
        r"""
        Pretty print module
        """
        return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"

    def forward(self, x: torch.Tensor):
        r"""
        Applies the inverse vector SHT.

        Parameters:
        x: complex tensor of shape (..., 2, lmax, mmax); index 0 / 1 along dim -3 are the
           spheroidal / toroidal coefficients

        Returns:
        real tensor of shape (..., 2, nlat, nlon) with the two vector components along dim -3

        Raises:
        ValueError: if x has fewer than 3 dimensions or its trailing dimensions are not (lmax, mmax)
        """

        if x.dim() < 3:
            raise ValueError(f"Expected tensor with at least 3 dimensions but got {x.dim()} instead")

        # explicit check rather than `assert`, which is stripped under `python -O`
        if x.shape[-2] != self.lmax or x.shape[-1] != self.mmax:
            raise ValueError(f"Expected trailing dimensions ({self.lmax}, {self.mmax}) but got {tuple(x.shape[-2:])} instead")

        # Evaluate associated Legendre functions on the output nodes
        x = torch.view_as_real(x)

        # cast each buffer component once instead of once per contraction
        d0 = self.dpct[0].to(x.dtype)
        d1 = self.dpct[1].to(x.dtype)

        # real / imaginary parts of the two coefficient components
        x0r = x[..., 0, :, :, 0]
        x0i = x[..., 0, :, :, 1]
        x1r = x[..., 1, :, :, 0]
        x1i = x[..., 1, :, :, 1]

        def contract(coeffs, d):
            # sum over the degrees l, evaluating on the latitude nodes k
            return torch.einsum("...lm,mlk->...km", coeffs, d)

        # contraction - spheroidal component (real, imag)
        srl = contract(x0r, d0) - contract(x1i, d1)
        sim = contract(x0i, d0) + contract(x1r, d1)

        # contraction - toroidal component (real, imag)
        trl = -contract(x0i, d1) - contract(x1r, d0)
        tim = contract(x0r, d1) - contract(x1i, d0)

        # reassemble into (..., 2, nlat, mmax, 2)
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # apply the inverse (real) FFT
        x = torch.view_as_complex(xs)
        x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")

        return x