# coding=utf-8 # SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # import torch import torch.nn as nn import torch_harmonics as th from torch_harmonics.quadrature import _precompute_longitudes import math import numpy as np class SphereSolver(nn.Module): """ Solver class on the sphere. Can solve the following PDEs: - Allen-Cahn equation - Ginzburg-Landau equation Parameters ----------- nlat : int Number of latitude points nlon : int Number of longitude points dt : float Time step size lmax : int, optional Maximum l mode for spherical harmonics, by default None mmax : int, optional Maximum m mode for spherical harmonics, by default None grid : str, optional Grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular" radius : float, optional Radius of the sphere, by default 1.0 coeff : float, optional Coefficient for the PDE, by default 0.001 """ def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid="equiangular", radius=1.0, coeff=0.001): super().__init__() # time stepping param self.dt = dt # grid parameters self.nlat = nlat self.nlon = nlon self.grid = grid # physical sonstants self.register_buffer('radius', torch.as_tensor(radius, dtype=torch.float64)) self.register_buffer('coeff', torch.as_tensor(coeff, dtype=torch.float64)) # SHT self.sht = th.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False) self.isht = th.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False) self.lmax = lmax or self.sht.lmax self.mmax = lmax or self.sht.mmax # compute gridpoints if self.grid == "legendre-gauss": cost, _ = th.quadrature.legendre_gauss_weights(self.nlat, -1, 1) elif self.grid == "lobatto": cost, _ = th.quadrature.lobatto_weights(self.nlat, -1, 1) elif self.grid == "equiangular": cost, _ = th.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1) # apply cosine transform and flip them lats = -torch.arcsin(cost) lons = _precompute_longitudes(self.nlon) self.lmax = self.sht.lmax self.mmax = self.sht.mmax l = torch.arange(0, self.lmax).reshape(self.lmax, 1).cdouble() l = l.expand(self.lmax, self.mmax) # the laplace operator acting on the coefficients is given by l (l + 1) lap = - l * (l + 1) / self.radius**2 invlap = - self.radius**2 / l / (l + 1) invlap[0] = 0. # register all self.register_buffer('lats', lats) self.register_buffer('lons', lons) self.register_buffer('l', l) self.register_buffer('lap', lap) self.register_buffer('invlap', invlap) def grid2spec(self, u): return self.sht(u) def spec2grid(self, uspec): """Convert spectral coefficients to spatial data.""" return self.isht(uspec) def dudtspec(self, uspec, pde='allen-cahn'): """Compute the time derivative of spectral coefficients for different PDEs.""" if pde == 'allen-cahn': ugrid = self.spec2grid(uspec) u3spec = self.grid2spec(ugrid**3) dudtspec = self.coeff*self.lap*uspec + uspec - u3spec elif pde == 'ginzburg-landau': ugrid = self.spec2grid(uspec) u3spec = self.grid2spec(ugrid**3) dudtspec = uspec + (1. + 2.j)*self.coeff*self.lap*uspec - (1. + 2.j)*u3spec else: raise NotImplementedError(f"PDE type {pde} not implemented") return dudtspec def randspec(self): """Generate random spectral data on the sphere.""" rspec = torch.randn_like(self.lap) / 4 / torch.pi return rspec def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False): """ Plot data on the sphere grid. Requires cartopy for 3d plots. Parameters ----------- data : torch.Tensor Data to plot fig : matplotlib.figure.Figure Figure to plot on cmap : str, optional Colormap name, by default 'twilight_shifted' vmax : float, optional Maximum value for color scaling, by default None vmin : float, optional Minimum value for color scaling, by default None projection : str, optional Projection type ("mollweide", "3d"), by default "3d" title : str, optional Plot title, by default None antialiased : bool, optional Whether to use antialiasing, by default False Returns ------- matplotlib.collections.QuadMesh The plotted image object Raises ------ NotImplementedError If projection type is not supported """ import matplotlib.pyplot as plt lons = self.lons.squeeze() - torch.pi lats = self.lats.squeeze() if data.is_cuda: data = data.cpu() lons = lons.cpu() lats = lats.cpu() Lons, Lats = np.meshgrid(lons, lats) if projection == 'mollweide': #ax = plt.gca(projection=projection) ax = fig.add_subplot(projection=projection) im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, vmax=vmax, vmin=vmin) # ax.set_title("Elevation map of mars") ax.grid(True) ax.set_xticklabels([]) ax.set_yticklabels([]) plt.colorbar(im, orientation='horizontal') plt.title(title) elif projection == '3d': import cartopy.crs as ccrs proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=25.0) #ax = plt.gca(projection=proj, frameon=True) ax = fig.add_subplot(projection=proj) Lons = Lons*180/math.pi Lats = Lats*180/math.pi # contour data over the map. im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin) plt.title(title, y=1.05) else: raise NotImplementedError(f"Projection {projection} not implemented") return im def plot_specdata(self, data, fig, **kwargs): """Plot spectral data by converting to spatial data first.""" return self.plot_griddata(self.isht(data), fig, **kwargs)