shallow_water_equations.py 15.7 KB
Newer Older
Boris Bonev's avatar
Boris Bonev committed
1
2
# coding=utf-8

# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#


import torch
import torch.nn as nn
Boris Bonev's avatar
Boris Bonev committed
35
import torch_harmonics as th
36
from torch_harmonics.quadrature import _precompute_longitudes
Boris Bonev's avatar
Boris Bonev committed
37

Thorsten Kurth's avatar
Thorsten Kurth committed
38
import math
Boris Bonev's avatar
Boris Bonev committed
39
40
41
42
43
import numpy as np


class ShallowWaterSolver(nn.Module):
    """
    Shallow Water Equations (SWE) solver class for spherical geometry.

    Interface inspired by pyspharm and SHTns. Solves the shallow water equations
    on a rotating sphere using spectral methods. The prognostic state ``uspec``
    stacks the spectral coefficients of (geopotential, vorticity, divergence)
    along its leading dimension.

    Parameters
    ----------
    nlat : int
        Number of latitude points
    nlon : int
        Number of longitude points
    dt : float
        Time step size (same time unit as ``omega``, i.e. seconds for the defaults)
    lmax : int, optional
        Maximum l mode for spherical harmonics, by default None
    mmax : int, optional
        Maximum m mode for spherical harmonics, by default None
    grid : str, optional
        Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
    radius : float, optional
        Radius of the sphere in meters, by default 6.37122E6 (Earth radius)
    omega : float, optional
        Angular velocity of rotation in rad/s, by default 7.292E-5 (Earth)
    gravity : float, optional
        Gravitational acceleration in m/s², by default 9.80616
    havg : float, optional
        Average height in meters, by default 10.e3
    hamp : float, optional
        Height amplitude in meters, by default 120.

    Raises
    ------
    ValueError
        If ``grid`` is not one of the supported grid types.
    """

    def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=6.37122E6,
                 omega=7.292E-5, gravity=9.80616, havg=10.e3, hamp=120.):
        super().__init__()

        # time stepping param
        self.dt = dt

        # grid parameters
        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid

        # physical constants, kept in double precision as buffers so they follow .to(device)
        self.register_buffer('radius', torch.as_tensor(radius, dtype=torch.float64))
        self.register_buffer('omega', torch.as_tensor(omega, dtype=torch.float64))
        self.register_buffer('gravity', torch.as_tensor(gravity, dtype=torch.float64))
        self.register_buffer('havg', torch.as_tensor(havg, dtype=torch.float64))
        self.register_buffer('hamp', torch.as_tensor(hamp, dtype=torch.float64))

        # scalar and vector spherical harmonic transforms, forward and inverse
        self.sht = th.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.isht = th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.vsht = th.RealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.ivsht = th.InverseRealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)

        # the transform resolves the effective truncation; every spectral array below uses it
        self.lmax = self.sht.lmax
        self.mmax = self.sht.mmax

        # quadrature nodes (cos of colatitude) and weights for the chosen grid
        if self.grid == "legendre-gauss":
            cost, quad_weights = th.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
        elif self.grid == "lobatto":
            cost, quad_weights = th.quadrature.lobatto_weights(self.nlat, -1, 1)
        elif self.grid == "equiangular":
            cost, quad_weights = th.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
        else:
            raise ValueError(f"Unknown grid type {self.grid}")

        # column vector so the weights broadcast along longitude
        quad_weights = quad_weights.reshape(-1, 1)

        # apply the inverse cosine transform to get latitudes; the sign flip reverses the
        # node ordering (presumably to match the SHT's latitude ordering - confirm)
        lats = -torch.arcsin(cost)
        lons = _precompute_longitudes(self.nlon)

        # laplace and inverse laplace operators acting on the spectral coefficients
        l = torch.arange(0, self.lmax).reshape(self.lmax, 1).double()
        l = l.expand(self.lmax, self.mmax)
        # the laplace operator acting on the coefficients is given by - l (l + 1) / r^2
        lap = - l * (l + 1) / self.radius**2
        invlap = - self.radius**2 / l / (l + 1)
        # l = 0 is in the kernel of the laplacian (the division above yields inf); define its inverse as 0
        invlap[0] = 0.

        # coriolis parameter 2 omega sin(lat), shaped (nlat, 1) to broadcast along longitude
        coriolis = 2 * self.omega * torch.sin(lats).reshape(self.nlat, 1)

        # hyperdiffusion damping factor, applied to vorticity and divergence after each time step
        hyperdiff = torch.exp((-self.dt / 2 / 3600.) * (lap / lap[-1, 0])**4)

        # register all
        self.register_buffer('lats', lats)
        self.register_buffer('lons', lons)
        self.register_buffer('l', l)
        self.register_buffer('lap', lap)
        self.register_buffer('invlap', invlap)
        self.register_buffer('coriolis', coriolis)
        self.register_buffer('hyperdiff', hyperdiff)
        self.register_buffer('quad_weights', quad_weights)

    def grid2spec(self, ugrid):
        """Convert spatial data to spectral coefficients using the scalar SHT."""
        return self.sht(ugrid)

    def spec2grid(self, uspec):
        """Convert spectral coefficients to spatial data using the inverse scalar SHT."""
        return self.isht(uspec)

    def vrtdivspec(self, ugrid):
        """
        Compute spectral vorticity and divergence from a gridded velocity field.

        The vector SHT coefficients are scaled by the laplacian and the sphere radius.
        """
        return self.lap * self.radius * self.vsht(ugrid)

    def getuv(self, vrtdivspec):
        """Compute the gridded wind vector from spectral coefficients of vorticity and divergence."""
        return self.ivsht(self.invlap * vrtdivspec / self.radius)

    def gethuv(self, uspec):
        """
        Compute gridded fields from a spectral state.

        Returns a tensor whose first channel is ``uspec[0]`` on the grid (the
        geopotential for states built by this class), followed by the two wind
        components recovered from ``uspec[1:]``, concatenated along dim -3.
        """
        hgrid = self.spec2grid(uspec[:1])
        uvgrid = self.getuv(uspec[1:])
        return torch.cat((hgrid, uvgrid), dim=-3)

    def potential_vorticity(self, uspec):
        """
        Compute a normalized potential vorticity on the grid.

        Evaluates (vorticity + coriolis) / geopotential, scaled by
        havg * gravity / (2 omega).
        """
        ugrid = self.spec2grid(uspec)
        pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0]
        return pvrt

    def dimensionless(self, uspec):
        """
        Remove dimensions from variables for dimensionless analysis.

        Channel 0 is shifted by the mean geopotential ``havg * gravity`` and scaled by
        ``hamp * gravity``. The remaining channels are scaled by ``radius / sqrt(gravity * havg)``.
        A new tensor is returned; the input is left unmodified.

        NOTE(review): subtracting a constant from every entry of channel 0 removes the mean
        only for gridded data; for spectral coefficients only the (0, 0) mode carries it.
        """
        # work on a copy so the caller's tensor (possibly a leaf requiring grad) is untouched
        out = uspec.clone()
        out[0] = (out[0] - self.havg * self.gravity) / self.hamp / self.gravity
        # vorticity is measured in 1/s so we normalize using sqrt(g h) / r
        out[1:] = out[1:] * self.radius / torch.sqrt(self.gravity * self.havg)
        return out

    def dudtspec(self, uspec):
        """
        Compute time derivatives from the solution represented in spectral coefficients.

        Returns the spectral tendencies of (geopotential, vorticity, divergence),
        matching the layout of ``uspec``.
        """
        dudtspec = torch.zeros_like(uspec)

        # compute the derivatives - this should be incorporated into the solver:
        ugrid = self.spec2grid(uspec)
        uvgrid = self.getuv(uspec[1:])

        # ugrid[0] is the geopotential, ugrid[1:] vorticity / divergence on the grid

        # absolute-vorticity flux drives the vorticity and divergence tendencies
        tmp = uvgrid * (ugrid[1] + self.coriolis)
        tmpspec = self.vrtdivspec(tmp)
        dudtspec[2] = tmpspec[0]
        dudtspec[1] = -1 * tmpspec[1]

        # geopotential tendency from the divergence of the geopotential flux
        tmp = uvgrid * ugrid[0]
        tmp = self.vrtdivspec(tmp)
        dudtspec[0] = -1 * tmp[1]

        # gradient of geopotential plus kinetic energy enters the divergence tendency
        tmpspec = self.grid2spec(ugrid[0] + 0.5 * (uvgrid[0]**2 + uvgrid[1]**2))
        dudtspec[2] = dudtspec[2] - self.lap * tmpspec

        return dudtspec

    def galewsky_initial_condition(self):
        """
        Initialize the non-linear barotropically unstable shallow water test case.

        A zonal jet confined between latitudes phi0 and phi1 is combined with a localized,
        unbalanced height bump. Returns the lower-triangular spectral state
        (geopotential, vorticity, divergence).
        """
        device = self.lap.device

        # test-case constants, evaluated in double precision
        umax = 80.
        phi0 = math.pi / 7.
        phi1 = 0.5 * math.pi - phi0
        phi2 = 0.25 * math.pi
        en = math.exp(-4.0 / (phi1 - phi0)**2)
        alpha = 1. / 3.
        beta = 1. / 15.

        lats, lons = torch.meshgrid(self.lats, self.lons, indexing="ij")

        # jet profile; values outside (phi0, phi1) are discarded by the where below
        u1 = (umax / en) * torch.exp(1. / ((lats - phi0) * (lats - phi1)))
        ugrid = torch.where(torch.logical_and(lats < phi1, lats > phi0), u1, torch.zeros_like(u1))
        vgrid = torch.zeros_like(ugrid)
        hbump = self.hamp * torch.cos(lats) * torch.exp(-((lons - torch.pi) / alpha)**2) * torch.exp(-(phi2 - lats)**2 / beta)

        # initial velocity field
        ugrid = torch.stack((ugrid, vgrid))
        # initial vorticity/divergence field
        vrtdivspec = self.vrtdivspec(ugrid)
        vrtdivgrid = self.spec2grid(vrtdivspec)

        # solve balance eqn to get initial zonal geopotential with a localized bump (not balanced).
        tmp = ugrid * (vrtdivgrid + self.coriolis)
        tmpspec = self.vrtdivspec(tmp)
        tmpspec[1] = self.grid2spec(0.5 * torch.sum(ugrid**2, dim=0))
        phispec = self.invlap * tmpspec[0] - tmpspec[1] + self.grid2spec(self.gravity * (self.havg + hbump))

        # assemble solution
        uspec = torch.zeros(3, self.lmax, self.mmax, dtype=vrtdivspec.dtype, device=device)
        uspec[0] = phispec
        uspec[1:] = vrtdivspec

        return torch.tril(uspec)

    def random_initial_condition(self, mach=0.1) -> torch.Tensor:
        """
        Generate a random initial condition on the sphere.

        Parameters
        ----------
        mach : float, optional
            Scale of vorticity/divergence relative to the wave speed sqrt(g * havg) / radius.

        Returns
        -------
        torch.Tensor
            Lower-triangular complex spectral state (geopotential, vorticity, divergence).
        """
        device = self.lap.device
        ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64

        # only modes with l < 120 and m < 120 receive noise
        llimit = mlimit = 120

        # complex gaussian noise in the retained modes, scaled by sqrt(4 pi / (L (L + 1)))
        uspec = torch.zeros(3, self.lmax, self.mmax, dtype=ctype, device=device)
        uspec[:, :llimit, :mlimit] = torch.sqrt(torch.tensor(4 * torch.pi / llimit / (llimit+1), device=device, dtype=ctype)) * torch.randn_like(uspec[:, :llimit, :mlimit])

        # geopotential: perturbation of amplitude g * hamp, plus the mean g * havg in the (0, 0) mode
        uspec[0] = self.gravity * self.hamp * uspec[0]
        uspec[0, 0, 0] += torch.sqrt(torch.tensor(4 * torch.pi, device=device, dtype=ctype)) * self.havg * self.gravity
        # vorticity/divergence: mach number relative to the wave speed sqrt(g h) / r
        uspec[1:] = mach * uspec[1:] * torch.sqrt(self.gravity * self.havg) / self.radius

        return torch.tril(uspec)

    def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor:
        """
        Integrate the solution with third-order Adams-Bashforth for nsteps steps.

        AB3 needs two previous tendencies, so the first step is forward Euler and the
        second is second-order Adams-Bashforth. After every step the hyperdiffusion
        factor is applied to vorticity and divergence. The input tensor is not modified
        (for nsteps == 0 it is returned as is).
        """
        dudt = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device)

        # ring-buffer slots in dudt: newest, previous and second-previous tendency
        inew = 0
        inow = 1
        iold = 2

        for step in range(nsteps):
            dudt[inew] = self.dudtspec(uspec)

            if step == 0:
                # forward euler: the AB3 weights sum to one, so equal tendencies reduce AB3 to euler
                dudt[inow] = dudt[inew]
                dudt[iold] = dudt[inew]
            elif step == 1:
                # 2nd-order adams-bashforth: with f_old = 2 f_now - f_new the AB3 weights
                # (23, -16, 5) / 12 collapse exactly to the AB2 weights (3, -1) / 2
                dudt[iold] = 2 * dudt[inow] - dudt[inew]

            # third-order adams-bashforth update of geopotential, vorticity and divergence
            uspec = uspec + self.dt * ((23. / 12.) * dudt[inew] - (16. / 12.) * dudt[inow] + (5. / 12.) * dudt[iold])

            # implicit hyperdiffusion for vort and div.
            uspec[1:] = self.hyperdiff * uspec[1:]

            # cycle through the indices
            inew = (inew - 1) % 3
            inow = (inow - 1) % 3
            iold = (iold - 1) % 3

        return uspec

    def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0):
        """
        Integrate gridded data over the sphere using the quadrature weights.

        Parameters
        ----------
        ugrid : torch.Tensor
            Gridded data; the last two dims are latitude and longitude.
        dimensionless : bool, optional
            If True integrate over the unit sphere, otherwise scale by radius**2.
        polar_opt : int, optional
            Number of latitude rows to exclude at each pole (0 keeps all rows).

        Returns
        -------
        torch.Tensor
            ``ugrid`` reduced over its last two dims.
        """
        dlon = 2 * torch.pi / self.nlon
        radius = 1 if dimensionless else self.radius
        if polar_opt > 0:
            out = torch.sum(ugrid[..., polar_opt:-polar_opt, :] * self.quad_weights[polar_opt:-polar_opt] * dlon * radius**2, dim=(-2, -1))
        else:
            out = torch.sum(ugrid * self.quad_weights * dlon * radius**2, dim=(-2, -1))
        return out

    def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
        """
        Plot data on the grid into ``fig``. Requires cartopy for the '3d' and 'robinson' projections.

        Parameters
        ----------
        data : torch.Tensor
            Gridded data of shape (nlat, nlon), on any device.
        fig : matplotlib.figure.Figure
            Figure to draw into; a new subplot is added to it.
        projection : str, optional
            One of 'mollweide', '3d' (orthographic) or 'robinson'.

        Returns
        -------
        The mesh returned by ``pcolormesh``.

        Raises
        ------
        NotImplementedError
            If ``projection`` is not supported.
        """
        # matplotlib/numpy need host tensors, regardless of where data or the solver live
        lons = self.lons.squeeze().detach().cpu() - torch.pi
        lats = self.lats.squeeze().detach().cpu()
        data = data.detach().cpu()

        Lons, Lats = np.meshgrid(lons, lats)

        if projection == 'mollweide':

            ax = fig.add_subplot(projection=projection)
            im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, vmax=vmax, vmin=vmin)
            ax.grid(True)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            fig.colorbar(im, ax=ax, orientation='horizontal')
            ax.set_title(title)

        elif projection in ('3d', 'robinson'):

            import cartopy.crs as ccrs

            if projection == '3d':
                proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0)
            else:
                proj = ccrs.Robinson(central_longitude=0.0)

            ax = fig.add_subplot(projection=proj)

            # cartopy's PlateCarree transform expects coordinates in degrees
            Lons = Lons * 180 / math.pi
            Lats = Lats * 180 / math.pi

            # contour data over the map.
            im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
            ax.set_title(title, y=1.05)

        else:
            raise NotImplementedError

        return im

    def plot_specdata(self, data, fig, **kwargs):
        """Plot spectral data by transforming it to the grid and calling ``plot_griddata``."""
        return self.plot_griddata(self.isht(data), fig, **kwargs)