Commit 6f4fe730 authored by Boris Bonev

v0.4 commit

The code was authored by the following people:
Boris Bonev - NVIDIA Corporation
Christian Hundt - NVIDIA Corporation
Thorsten Kurth - NVIDIA Corporation
# Changelog
## Versioning
### v0.4
* Computation of associated Legendre polynomials
    * changed algorithm to compute the associated Legendre polynomials for improved stability
* Improved Readme
### v0.3
* Vector Spherical Harmonic Transforms
    * projects vector-valued fields onto the vector Spherical Harmonics
    * supports computation of div and curl on the sphere
* New quadrature rules
    * Clenshaw-Curtis quadrature rule
    * Fejér quadrature rule
    * Legendre-Gauss-Lobatto quadrature
* New notebooks
    * complete with differentiable Shallow Water Solver
    * notebook on quadrature and interpolation
* Unit tests
* Refactor of the API
### v0.2
* Renaming from torch_sht to torch_harmonics
* Adding distributed SHT support
* New logo
### v0.1
* Single GPU forward and backward transform
* Minimal code example and notebook
<!-- ## Detailed logs
### 23-11-2022
* Initialized the library
* Added `getting_started.ipynb` example
* Added simple example to test the SHT
* Logo -->
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# build after cloning in directory torch_harmonics via
# docker build . -t torch_harmonics
FROM nvcr.io/nvidia/pytorch:22.08-py3
COPY . /workspace/torch_harmonics
RUN pip install --use-feature=in-tree-build /workspace/torch_harmonics
SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<p align="center">
<img src="./images/logo/logo.png" width="568">
</p>
<!-- # torch-harmonics: differentiable harmonic transforms -->
<!-- ## What is torch-harmonics? -->
`torch_harmonics` is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It uses quadrature to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. For most practical purposes, this algorithm tends to outperform alternatives that have better asymptotic scaling.
<table border="0" cellspacing="0" cellpadding="0">
<tr>
<td><img src="./images/zonal_jet.gif" width="288"></td>
<td><img src="./images/ginzburg-landau.gif" width="288"></td>
<td><img src="./images/allen-cahn.gif" width="288"></td>
</tr>
<tr>
<td style="text-align:center; border-style : hidden!important;">Shallow Water Eqns.</td>
<td style="text-align:center; border-style : hidden!important;">Ginzburg-Landau Eqn.</td>
<td style="text-align:center; border-style : hidden!important;">Allen-Cahn Eqn.</td>
</tr>
</table>
<!-- <p align="left">
<img src="./images/zonal_jet.gif" width="288">
<img src="./images/allen-cahn.gif" width="288">
</p> -->
## Installation
Build in your environment using the Python package:
```
git clone git@github.com:NVIDIA/torch-harmonics.git
pip install ./torch-harmonics
```
Alternatively, use the Dockerfile to build your custom container after cloning:
```
git clone git@github.com:NVIDIA/torch-harmonics.git
cd torch-harmonics
docker build . -t torch_harmonics
docker run --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 torch_harmonics
```
## Contributors
- Boris Bonev (bbonev@nvidia.com)
- Christian Hundt (chundt@nvidia.com)
- Thorsten Kurth (tkurth@nvidia.com)
## Implementation
The implementation follows the paper "Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations", N. Schaeffer, G3: Geochemistry, Geophysics, Geosystems.
### Spherical harmonic transform
The truncated series expansion of a function $f$ defined on the surface of a sphere can be written as
$$
f(\theta, \lambda) = \sum_{m=-M}^{M} \exp(im\lambda) \sum_{n=|m|}^{M} F_n^m \bar{P}_n^m (\cos \theta),
$$
where $\theta$ is the colatitude, $\lambda$ the longitude, $\bar{P}_n^m$ the normalized associated Legendre polynomials, and $F_n^m$ the expansion coefficient associated with the mode $(m,n)$.
A direct spherical harmonic transform can be accomplished by a Fourier transform
$$
F^m(\theta) = \frac{1}{2 \pi} \int_{0}^{2\pi} f(\theta, \lambda) \exp(-im\lambda) \mathrm{d}\lambda
$$
in longitude and a Legendre transform
$$
F_n^m = \frac{1}{2} \int_{-1}^1 F^m(\theta) \bar{P}_n^m(\cos \theta) \mathrm{d} \cos \theta
$$
in latitude.
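As a rough illustration of the longitudinal step, the Fourier integral can be discretized on a single latitude ring with equispaced longitudes, where it reduces to a normalized discrete Fourier transform. The sketch below uses only the Python standard library; the helper name `longitude_fourier_coeffs` and the test signal are assumptions made for this example and are not part of `torch_harmonics`, which uses FFTs for this step.

```python
import cmath
import math

def longitude_fourier_coeffs(samples, mmax):
    """Approximate F^m = (1/2pi) * int f(lambda) exp(-i m lambda) d lambda
    for m = 0..mmax from equispaced samples on one latitude ring."""
    nlon = len(samples)
    coeffs = []
    for m in range(mmax + 1):
        acc = 0j
        for k, value in enumerate(samples):
            lam = 2.0 * math.pi * k / nlon  # longitude of the k-th sample
            acc += value * cmath.exp(-1j * m * lam)
        coeffs.append(acc / nlon)  # rectangle rule: d lambda / (2 pi) = 1 / nlon
    return coeffs

# hypothetical test signal: f(lambda) = 1 + 2 cos(3 lambda),
# whose only non-zero coefficients for m >= 0 are F^0 = 1 and F^3 = 1
nlon = 16
ring = [1.0 + 2.0 * math.cos(3 * 2.0 * math.pi * k / nlon) for k in range(nlon)]
coeffs = longitude_fourier_coeffs(ring, mmax=4)
```

The direct double loop costs $O(N_\lambda M)$ per ring and is shown only for readability; an FFT computes the same coefficients more efficiently.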
### Discrete Legendre transform
To apply the Legendre transform, we use Gauss-Legendre points in the latitudinal direction. The integral
$$
F_n^m = \frac{1}{2} \int_{0}^\pi F^m(\theta) \bar{P}_n^m(\cos \theta) \sin \theta \mathrm{d} \theta
$$
is approximated by the sum
$$
F_n^m = \frac{1}{2} \sum_{j=1}^{N_\theta} F^m(\theta_j) \bar{P}_n^m(\cos \theta_j) w_j
$$
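The quadrature step can be sketched in plain Python for the zonal case $m = 0$. This is an illustrative sketch under stated assumptions, not the library's implementation: the helpers `legendre`, `gauss_legendre` and `legendre_transform_m0` are hypothetical names, the nodes are found by Newton iteration, and the normalization $\bar{P}_n^0 = \sqrt{2n+1}\,P_n$ is chosen so that it is orthonormal under the factor $\frac{1}{2}$ used in the Legendre transform above.

```python
import math

def legendre(n, x):
    """Return (P_n(x), P_n'(x)) via the three-term recurrence (requires |x| < 1 for n >= 1)."""
    if n == 0:
        return 1.0, 0.0
    p_prev, p = 1.0, x
    for k in range(2, n + 1):
        p_prev, p = p, ((2 * k - 1) * x * p - (k - 1) * p_prev) / k
    dp = n * (x * p - p_prev) / (x * x - 1.0)
    return p, dp

def gauss_legendre(n):
    """Nodes x_j = cos(theta_j) and weights w_j of the n-point Gauss-Legendre rule."""
    nodes, weights = [], []
    for j in range(1, n + 1):
        x = math.cos(math.pi * (j - 0.25) / (n + 0.5))  # standard initial guess
        for _ in range(100):  # Newton iteration on P_n(x) = 0
            p, dp = legendre(n, x)
            step = p / dp
            x -= step
            if abs(step) < 1e-15:
                break
        _, dp = legendre(n, x)
        nodes.append(x)
        weights.append(2.0 / ((1.0 - x * x) * dp * dp))
    return nodes, weights

def pbar(n, x):
    """Assumed normalization: (1/2) * int_{-1}^{1} pbar_n * pbar_k dx = delta_nk."""
    return math.sqrt(2 * n + 1) * legendre(n, x)[0]

def legendre_transform_m0(values, nodes, weights, nmax):
    """F_n = (1/2) * sum_j F(theta_j) * pbar_n(cos theta_j) * w_j for n = 0..nmax."""
    return [0.5 * sum(f * pbar(n, x) * w for f, x, w in zip(values, nodes, weights))
            for n in range(nmax + 1)]

nodes, weights = gauss_legendre(8)
# hypothetical band-limited input with coefficients F_0 = 0.5 and F_2 = 3
values = [0.5 * pbar(0, x) + 3.0 * pbar(2, x) for x in nodes]
coeffs = legendre_transform_m0(values, nodes, weights, nmax=3)
```

Because an $N_\theta$-point Gauss-Legendre rule integrates polynomials up to degree $2N_\theta - 1$ exactly, the coefficients of a band-limited input are recovered up to rounding error.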
## Usage
### Getting started
The main functionality of `torch_harmonics` is provided in the form of `torch.nn.Module`s for composability. A minimal example is given by:
```python
import torch
import torch_harmonics as harmonics
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
nlat = 512
nlon = 2*nlat
batch_size = 32
signal = torch.randn(batch_size, nlat, nlon, device=device)
# transform data on an equiangular grid
sht = harmonics.RealSHT(nlat, nlon, grid="equiangular").to(device).float()
coeffs = sht(signal)
```
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
sys.path.append("..")
sys.path.append(".")
import torch
import torch_harmonics as harmonics
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x : x
# everything is awesome on GPUs
device = torch.device("cuda")
# create a batch with one sample and 21 channels
b, c, n_theta, n_lambda = 1, 21, 360, 720
# your layers to play with
forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
forward_transform_equi = harmonics.RealSHT(n_theta, n_lambda, grid="equiangular").to(device)
inverse_transform_equi = harmonics.InverseRealSHT(n_theta, n_lambda, grid="equiangular").to(device)
signal_leggauss = inverse_transform(torch.randn(b, c, n_theta, n_theta+1, device=device, dtype=torch.complex128))
signal_equi = inverse_transform_equi(torch.randn(b, c, n_theta, n_theta+1, device=device, dtype=torch.complex128))
# let's check the layers
for num_iters in [1, 8, 64, 512]:
    base = signal_leggauss
    for iteration in tqdm(range(num_iters)):
        base = inverse_transform(forward_transform(base))
    print("relative l2 error accumulation on the legendre-gauss grid: ",
          torch.mean(torch.norm(base-signal_leggauss, p='fro', dim=(-1,-2)) / torch.norm(signal_leggauss, p='fro', dim=(-1,-2)) ).item(),
          "after", num_iters, "iterations")
# let's check the equiangular layers
for num_iters in [1, 8, 64, 512]:
    base = signal_equi
    for iteration in tqdm(range(num_iters)):
        base = inverse_transform_equi(forward_transform_equi(base))
    print("relative l2 error accumulation with interpolation onto equiangular grid: ",
          torch.mean(torch.norm(base-signal_equi, p='fro', dim=(-1,-2)) / torch.norm(signal_equi, p='fro', dim=(-1,-2)) ).item(),
          "after", num_iters, "iterations")
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
import os
sys.path.append("..")
sys.path.append(".")
import torch
import torch.distributed as dist
import torch_harmonics as harmonics
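try:
    from tqdm import tqdm
except ImportError:
    # fallback that accepts (and ignores) tqdm keyword arguments such as `disable`
    def tqdm(iterable, **kwargs):
        return iterable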
# set up distributed
world_size = int(os.getenv('WORLD_SIZE', 1))
world_rank = int(os.getenv('WORLD_RANK', 0))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
dist.init_process_group(backend = 'nccl',
                        init_method = f"tcp://{master_address}:{port}",
                        rank = world_rank,
                        world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
mp_group = dist.new_group(ranks=list(range(world_size)))
my_rank = dist.get_rank(mp_group)
group_size = 1 if not dist.is_initialized() else dist.get_world_size(mp_group)
if my_rank == 0:
    print(f"Running distributed test on {group_size} ranks.")
# init distributed SHT:
harmonics.distributed.init(mp_group)
# everything is awesome on GPUs
device = torch.device(f"cuda:{local_rank}")
# create a batch with one sample and 21 channels
b, c, n_theta, n_lambda = 1, 21, 360, 720
# your layers to play with
forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
forward_transform_equi = harmonics.RealSHT(n_theta, n_lambda, grid="equiangular").to(device)
inverse_transform_equi = harmonics.InverseRealSHT(n_theta, n_lambda, grid="equiangular").to(device)
signal_leggauss = inverse_transform(torch.randn(b, c, n_theta // group_size, n_theta+1, device=device, dtype=torch.complex128))
signal_equi = inverse_transform_equi(torch.randn(b, c, n_theta // group_size, n_theta+1, device=device, dtype=torch.complex128))
# let's check the layers
for num_iters in [1, 8, 64, 512]:
    base = signal_leggauss
    for iteration in tqdm(range(num_iters), disable=(my_rank!=0)):
        base = inverse_transform(forward_transform(base))
    # compute error:
    numerator = torch.sum(torch.square(torch.abs(base-signal_leggauss)), dim=(-1,-2))
    denominator = torch.sum(torch.square(torch.abs(signal_leggauss)), dim=(-1,-2))
    if dist.is_initialized():
        dist.all_reduce(numerator, group=mp_group)
        dist.all_reduce(denominator, group=mp_group)
    if my_rank == 0:
        print("relative l2 error accumulation on the legendre-gauss grid: ",
              torch.mean(torch.sqrt(numerator / denominator)).item(),
              "after", num_iters, "iterations")
# let's check the equiangular layers
for num_iters in [1, 8, 64, 512]:
    base = signal_equi
    for iteration in tqdm(range(num_iters), disable=(my_rank!=0)):
        base = inverse_transform_equi(forward_transform_equi(base))
    # compute error
    numerator = torch.sum(torch.square(torch.abs(base-signal_equi)), dim=(-1,-2))
    denominator = torch.sum(torch.square(torch.abs(signal_equi)), dim=(-1,-2))
    if dist.is_initialized():
        dist.all_reduce(numerator, group=mp_group)
        dist.all_reduce(denominator, group=mp_group)
    if my_rank == 0:
        print("relative l2 error accumulation with interpolation onto equiangular grid: ",
              torch.mean(torch.sqrt(numerator / denominator)).item(),
              "after", num_iters, "iterations")
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import sys
sys.path.append("..")
sys.path.append(".")
import torch
import torch.nn as nn
import torch_harmonics as harmonics
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.tri as mtri
try:
    import cartopy.crs as ccrs
except ImportError:
    ccrs = None
class SphereSolver(nn.Module):
    """
    Solver class on the sphere. Can solve the following PDEs:
    - Allen-Cahn eq
    - Ginzburg-Landau eq
    """
    def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid='legendre-gauss', radius=1.0, coeff=0.001):
        super().__init__()
        # time stepping param
        self.dt = dt
        # grid parameters
        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        # physical constants
        self.register_buffer('radius', torch.as_tensor(radius, dtype=torch.float64))
        self.register_buffer('coeff', torch.as_tensor(coeff, dtype=torch.float64))
        # SHT
        self.sht = harmonics.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.isht = harmonics.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.lmax = lmax or self.sht.lmax
        self.mmax = mmax or self.sht.mmax
        # compute gridpoints
        if self.grid == "legendre-gauss":
            cost, _ = harmonics.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
        elif self.grid == "lobatto":
            cost, _ = harmonics.quadrature.lobatto_weights(self.nlat, -1, 1)
        elif self.grid == "equiangular":
            cost, _ = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
        # apply cosine transform and flip them
        lats = -torch.as_tensor(np.arcsin(cost))
        lons = torch.linspace(0, 2*np.pi, self.nlon+1, dtype=torch.float64)[:nlon]
        self.lmax = self.sht.lmax
        self.mmax = self.sht.mmax
        l = torch.arange(0, self.lmax).reshape(self.lmax, 1).cdouble()
        l = l.expand(self.lmax, self.mmax)
        # the laplace operator acting on the coefficients is given by - l (l + 1)
        lap = - l * (l + 1) / self.radius**2
        invlap = - self.radius**2 / l / (l + 1)
        invlap[0] = 0.
        # register all
        self.register_buffer('lats', lats)
        self.register_buffer('lons', lons)
        self.register_buffer('l', l)
        self.register_buffer('lap', lap)
        self.register_buffer('invlap', invlap)
    def grid2spec(self, u):
        """spectral coefficients from spatial data"""
        return self.sht(u)
    def spec2grid(self, uspec):
        """spatial data from spectral coefficients"""
        return self.isht(uspec)
    def dudtspec(self, uspec, pde='allen-cahn'):
        if pde == 'allen-cahn':
            ugrid = self.spec2grid(uspec)
            u3spec = self.grid2spec(ugrid**3)
            dudtspec = self.coeff*self.lap*uspec + uspec - u3spec
        elif pde == 'ginzburg-landau':
            ugrid = self.spec2grid(uspec)
            u3spec = self.grid2spec(ugrid**3)
            dudtspec = uspec + (1. + 2.j)*self.coeff*self.lap*uspec - (1. + 2.j)*u3spec
        else:
            # unknown PDE: fail loudly instead of falling through to an undefined result
            raise NotImplementedError
        return dudtspec
    def randspec(self):
        """random data on the sphere"""
        rspec = torch.randn_like(self.lap) / 4 / torch.pi
        return rspec
    def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
        """
        plotting routine for data on the grid. Requires cartopy for 3d plots.
        """
        lons = self.lons.squeeze() - torch.pi
        lats = self.lats.squeeze()
        if data.is_cuda:
            data = data.cpu()
            lons = lons.cpu()
            lats = lats.cpu()
        Lons, Lats = np.meshgrid(lons, lats)
        if projection == 'mollweide':
            #ax = plt.gca(projection=projection)
            ax = fig.add_subplot(projection=projection)
            im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, vmax=vmax, vmin=vmin)
            # ax.set_title("Elevation map of mars")
            ax.grid(True)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            plt.colorbar(im, orientation='horizontal')
            plt.title(title)
        elif projection == '3d':
            if ccrs is None:
                raise ImportError("Couldn't import Cartopy")
            proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=45.0)
            #ax = plt.gca(projection=proj, frameon=True)
            ax = fig.add_subplot(projection=proj)
            Lons = Lons*180/np.pi
            Lats = Lats*180/np.pi
            # contour data over the map.
            im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
            plt.title(title, y=1.05)
        else:
            raise NotImplementedError
        return im
    def plot_specdata(self, data, fig, **kwargs):
        return self.plot_griddata(self.isht(data), fig, **kwargs)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import sys
sys.path.append("..")
sys.path.append(".")
import torch
import torch.nn as nn
import torch_harmonics as harmonics
from torch_harmonics.quadrature import *
import numpy as np
import matplotlib.pyplot as plt
try:
    import cartopy.crs as ccrs
except ImportError:
    ccrs = None
class ShallowWaterSolver(nn.Module):
    """
    SWE solver class. Interface inspired by pyspharm and SHTns
    """
    def __init__(self, nlat, nlon, dt, lmax=None, mmax=None, grid='legendre-gauss', radius=6.37122E6, \
                 omega=7.292E-5, gravity=9.80616, havg=10.e3, hamp=120.):
        super().__init__()
        # time stepping param
        self.dt = dt
        # grid parameters
        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        # physical constants
        self.register_buffer('radius', torch.as_tensor(radius, dtype=torch.float64))
        self.register_buffer('omega', torch.as_tensor(omega, dtype=torch.float64))
        self.register_buffer('gravity', torch.as_tensor(gravity, dtype=torch.float64))
        self.register_buffer('havg', torch.as_tensor(havg, dtype=torch.float64))
        self.register_buffer('hamp', torch.as_tensor(hamp, dtype=torch.float64))
        # SHT
        self.sht = harmonics.RealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.isht = harmonics.InverseRealSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.vsht = harmonics.RealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.ivsht = harmonics.InverseRealVectorSHT(nlat, nlon, lmax=lmax, mmax=mmax, grid=grid, csphase=False)
        self.lmax = lmax or self.sht.lmax
        self.mmax = mmax or self.sht.mmax
        # compute gridpoints
        if self.grid == "legendre-gauss":
            cost, quad_weights = harmonics.quadrature.legendre_gauss_weights(self.nlat, -1, 1)
        elif self.grid == "lobatto":
            cost, quad_weights = harmonics.quadrature.lobatto_weights(self.nlat, -1, 1)
        elif self.grid == "equiangular":
            cost, quad_weights = harmonics.quadrature.clenshaw_curtiss_weights(self.nlat, -1, 1)
        quad_weights = torch.as_tensor(quad_weights).reshape(-1, 1)
        # apply cosine transform and flip them
        lats = -torch.as_tensor(np.arcsin(cost))
        lons = torch.linspace(0, 2*np.pi, self.nlon+1, dtype=torch.float64)[:nlon]
        self.lmax = self.sht.lmax
        self.mmax = self.sht.mmax
        # compute the laplace and inverse laplace operators
        l = torch.arange(0, self.lmax).reshape(self.lmax, 1).double()
        l = l.expand(self.lmax, self.mmax)
        # the laplace operator acting on the coefficients is given by - l (l + 1)
        lap = - l * (l + 1) / self.radius**2
        invlap = - self.radius**2 / l / (l + 1)
        invlap[0] = 0.
        # compute coriolis force
        coriolis = 2 * self.omega * torch.sin(lats).reshape(self.nlat, 1)
        # hyperdiffusion
        hyperdiff = torch.exp(torch.asarray((-self.dt / 2 / 3600.)*(lap / lap[-1, 0])**4))
        # register all
        self.register_buffer('lats', lats)
        self.register_buffer('lons', lons)
        self.register_buffer('l', l)
        self.register_buffer('lap', lap)
        self.register_buffer('invlap', invlap)
        self.register_buffer('coriolis', coriolis)
        self.register_buffer('hyperdiff', hyperdiff)
        self.register_buffer('quad_weights', quad_weights)
    def grid2spec(self, ugrid):
        """
        spectral coefficients from spatial data
        """
        return self.sht(ugrid)
    def spec2grid(self, uspec):
        """
        spatial data from spectral coefficients
        """
        return self.isht(uspec)
    def vrtdivspec(self, ugrid):
        """spectral coefficients of vorticity and divergence from the wind vector on the grid"""
        vrtdivspec = self.lap * self.radius * self.vsht(ugrid)
        return vrtdivspec
    def getuv(self, vrtdivspec):
        """
        compute wind vector from spectral coeffs of vorticity and divergence
        """
        return self.ivsht( self.invlap * vrtdivspec / self.radius)
    def gethuv(self, uspec):
        """
        compute geopotential and wind vector from spectral coeffs of geopotential, vorticity and divergence
        """
        hgrid = self.spec2grid(uspec[:1])
        uvgrid = self.getuv(uspec[1:])
        return torch.cat((hgrid, uvgrid), dim=-3)
    def potential_vorticity(self, uspec):
        """
        Compute potential vorticity
        """
        ugrid = self.spec2grid(uspec)
        pvrt = (0.5 * self.havg * self.gravity / self.omega) * (ugrid[1] + self.coriolis) / ugrid[0]
        return pvrt
    def dimensionless(self, uspec):
        """
        Remove dimensions from variables
        """
        uspec[0] = (uspec[0] - self.havg * self.gravity) / self.hamp / self.gravity
        # vorticity is measured in 1/s so we normalize using sqrt(g h) / r
        uspec[1:] = uspec[1:] * self.radius / torch.sqrt(self.gravity * self.havg)
        return uspec
    def dudtspec(self, uspec):
        """
        Compute time derivatives from solution represented in spectral coefficients
        """
        dudtspec = torch.zeros_like(uspec)
        # compute the derivatives - this should be incorporated into the solver:
        ugrid = self.spec2grid(uspec)
        uvgrid = self.getuv(uspec[1:])
        # phi = ugrid[0]
        # vrtdiv = ugrid[1:]
        tmp = uvgrid * (ugrid[1] + self.coriolis)
        tmpspec = self.vrtdivspec(tmp)
        dudtspec[2] = tmpspec[0]
        dudtspec[1] = -1 * tmpspec[1]
        tmp = uvgrid * ugrid[0]
        tmp = self.vrtdivspec(tmp)
        dudtspec[0] = -1 * tmp[1]
        tmpspec = self.grid2spec(ugrid[0] + 0.5 * (uvgrid[0]**2 + uvgrid[1]**2))
        dudtspec[2] = dudtspec[2] - self.lap * tmpspec
        return dudtspec
    def galewsky_initial_condition(self):
        """
        Initializes non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
        [1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
        DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
        """
        device = self.lap.device
        umax = 80.
        phi0 = torch.asarray(torch.pi / 7., device=device)
        phi1 = torch.asarray(0.5 * torch.pi - phi0, device=device)
        phi2 = 0.25 * torch.pi
        en = torch.exp(torch.asarray(-4.0 / (phi1 - phi0)**2, device=device))
        alpha = 1. / 3.
        beta = 1. / 15.
        # 'ij' indexing matches the (nlat, nlon) layout used throughout
        lats, lons = torch.meshgrid(self.lats, self.lons, indexing='ij')
        u1 = (umax/en)*torch.exp(1./((lats-phi0)*(lats-phi1)))
        ugrid = torch.where(torch.logical_and(lats < phi1, lats > phi0), u1, torch.zeros(self.nlat, self.nlon, device=device))
        vgrid = torch.zeros((self.nlat, self.nlon), device=device)
        hbump = self.hamp * torch.cos(lats) * torch.exp(-((lons-torch.pi)/alpha)**2) * torch.exp(-(phi2-lats)**2/beta)
        # initial velocity field
        ugrid = torch.stack((ugrid, vgrid))
        # initial vorticity/divergence field
        vrtdivspec = self.vrtdivspec(ugrid)
        vrtdivgrid = self.spec2grid(vrtdivspec)
        # solve balance eqn to get initial zonal geopotential with a localized bump (not balanced).
        tmp = ugrid * (vrtdivgrid + self.coriolis)
        tmpspec = self.vrtdivspec(tmp)
        tmpspec[1] = self.grid2spec(0.5 * torch.sum(ugrid**2, dim=0))
        phispec = self.invlap*tmpspec[0] - tmpspec[1] + self.grid2spec(self.gravity*(self.havg + hbump))
        # assemble solution
        uspec = torch.zeros(3, self.lmax, self.mmax, dtype=vrtdivspec.dtype, device=device)
        uspec[0] = phispec
        uspec[1:] = vrtdivspec
        return torch.tril(uspec)
    def random_initial_condition(self, mach=0.1) -> torch.Tensor:
        """
        random initial condition on the sphere
        """
        device = self.lap.device
        ctype = torch.complex128 if self.lap.dtype == torch.float64 else torch.complex64
        # mach number relative to wave speed
        llimit = mlimit = 20
        # hgrid = self.havg + hamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
        # ugrid = uamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
        # vgrid = vamp * torch.randn(self.nlat, self.nlon, device=device, dtype=dtype)
        # ugrid = torch.stack((ugrid, vgrid))
        # initial geopotential
        uspec = torch.zeros(3, self.lmax, self.mmax, dtype=ctype, device=self.lap.device)
        uspec[:, :llimit, :mlimit] = torch.sqrt(torch.tensor(4 * torch.pi / llimit / (llimit+1), device=device, dtype=ctype)) * torch.randn_like(uspec[:, :llimit, :mlimit])
        uspec[0] = self.gravity * self.hamp * uspec[0]
        uspec[0, 0, 0] += torch.sqrt(torch.tensor(4 * torch.pi, device=device, dtype=ctype)) * self.havg * self.gravity
        uspec[1:] = mach * uspec[1:] * torch.sqrt(self.gravity * self.havg) / self.radius
        # uspec[1:] = self.vrtdivspec(self.spec2grid(uspec[1:]) * torch.cos(self.lats.reshape(-1, 1)))
        # # initial velocity field
        # ugrid = uamp * self.spec2grid(uspec[1])
        # vgrid = vamp * self.spec2grid(uspec[2])
        # ugrid = torch.stack((ugrid, vgrid))
        # # initial vorticity/divergence field
        # vrtdivspec = self.vrtdivspec(ugrid)
        # vrtdivgrid = self.spec2grid(vrtdivspec)
        # # solve balance eqn to get initial zonal geopotential with a localized bump (not balanced).
        # tmp = ugrid * (vrtdivgrid + self.coriolis)
        # tmpspec = self.vrtdivspec(tmp)
        # tmpspec[1] = self.grid2spec(0.5 * torch.sum(ugrid**2, dim=0))
        # phispec = self.invlap*tmpspec[0] - tmpspec[1] + self.grid2spec(self.gravity * hgrid)
        # # assemble solution
        # uspec = torch.zeros(3, self.lmax, self.mmax, dtype=phispec.dtype, device=device)
        # uspec[0] = phispec
        # uspec[1:] = vrtdivspec
        return torch.tril(uspec)

    def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor:
        """
        Integrate the solution for nsteps steps using third-order Adams-Bashforth,
        started with a forward Euler step.
        """
        dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device)

        # indices of the newest, current and oldest tendencies stored in dudtspec
        inew = 0
        inow = 1
        iold = 2

        for step in range(nsteps):
            dudtspec[inew] = self.dudtspec(uspec)

            # update vort, div, phi with third-order Adams-Bashforth.
            # startup: on the first step all three slots hold the same tendency, which
            # reduces the update to forward Euler; on the second step the oldest slot
            # is filled with the newest tendency.
            if step == 0:
                dudtspec[inow] = dudtspec[inew]
                dudtspec[iold] = dudtspec[inew]
            elif step == 1:
                dudtspec[iold] = dudtspec[inew]

            uspec = uspec + self.dt*( (23./12.) * dudtspec[inew] - (16./12.) * dudtspec[inow] + (5./12.) * dudtspec[iold] )

            # implicit hyperdiffusion for vort and div.
            uspec[1:] = self.hyperdiff * uspec[1:]

            # cycle through the indices
            inew = (inew - 1) % 3
            inow = (inow - 1) % 3
            iold = (iold - 1) % 3

        return uspec

    def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0):
        """
        Integrate grid data over the sphere using the latitudinal quadrature weights.
        If polar_opt > 0, that many latitude rings are excluded at each pole.
        """
        dlon = 2 * torch.pi / self.nlon
        radius = 1 if dimensionless else self.radius
        if polar_opt > 0:
            out = torch.sum(ugrid[..., polar_opt:-polar_opt, :] * self.quad_weights[polar_opt:-polar_opt] * dlon * radius**2, dim=(-2, -1))
        else:
            out = torch.sum(ugrid * self.quad_weights * dlon * radius**2, dim=(-2, -1))
        return out

    def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=True):
        """
        plotting routine for data on the grid. Requires cartopy for 3d plots.
        """
        lons = self.lons.squeeze() - torch.pi
        lats = self.lats.squeeze()

        if data.is_cuda:
            data = data.cpu()
            lons = lons.cpu()
            lats = lats.cpu()

        Lons, Lats = np.meshgrid(lons, lats)

        if projection == 'mollweide':
            #ax = plt.gca(projection=projection)
            ax = fig.add_subplot(projection=projection)
            im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, vmax=vmax, vmin=vmin)
            # ax.set_title("Elevation map of mars")
            ax.grid(True)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            plt.colorbar(im, orientation='horizontal')
            plt.title(title)

        elif projection == '3d':
            if ccrs is None:
                raise ImportError("Couldn't import Cartopy")

            proj = ccrs.Orthographic(central_longitude=0.0, central_latitude=45.0)
            #ax = plt.gca(projection=proj, frameon=True)
            ax = fig.add_subplot(projection=proj)
            Lons = Lons*180/np.pi
            Lats = Lats*180/np.pi

            # contour data over the map.
            im = ax.pcolormesh(Lons, Lats, data, cmap=cmap, transform=ccrs.PlateCarree(), antialiased=antialiased, vmax=vmax, vmin=vmin)
            plt.title(title, y=1.05)

        else:
            raise NotImplementedError

        return im

    def plot_specdata(self, data, fig, **kwargs):
        return self.plot_griddata(self.isht(data), fig, **kwargs)
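
The `timestep` method above advances the spectral state with a three-slot history of tendencies. As a rough illustration of the same idea on a scalar ODE, here is a minimal sketch in plain Python. The function name `ab3_integrate` and the use of a `deque` are assumptions made for this example and are not part of the solver. Note also that this sketch bootstraps with a forward Euler step followed by a standard second-order Adams-Bashforth step, which is a textbook startup and not necessarily identical to the startup used in the solver.

```python
import math
from collections import deque


def ab3_integrate(f, y0, dt, nsteps):
    """Integrate dy/dt = f(y) with third-order Adams-Bashforth.

    The first two steps fall back to lower-order formulas because
    fewer than three tendencies are available yet.
    """
    y = y0
    history = deque(maxlen=3)  # newest tendency is stored at index 0
    for _ in range(nsteps):
        history.appendleft(f(y))
        if len(history) == 1:
            # forward Euler
            incr = history[0]
        elif len(history) == 2:
            # second-order Adams-Bashforth
            incr = 1.5 * history[0] - 0.5 * history[1]
        else:
            # third-order Adams-Bashforth
            incr = (23.0 * history[0] - 16.0 * history[1] + 5.0 * history[2]) / 12.0
        y = y + dt * incr
    return y


if __name__ == "__main__":
    # exponential decay dy/dt = -y, integrated from t = 0 to t = 1
    approx = ab3_integrate(lambda y: -y, 1.0, 0.01, 100)
    print("absolute error:", abs(approx - math.exp(-1.0)))
```

The bounded `deque` plays the role of the cycling `inew`, `inow`, `iold` indices: appending a new tendency on the left discards the oldest one automatically.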
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from setuptools import setup
setup(
    name='torch_harmonics',
    version='0.4',
    author='Boris Bonev',
    author_email='bbonev@nvidia.com',
    packages=['torch_harmonics',],
    scripts=[],
    url='https://github.com/NVIDIA/torch-harmonics',
    license='LICENSE.md',
    description='a differentiable spherical harmonic transform for PyTorch',
    long_description=open('README.md').read(),
    install_requires=['torch','numpy']
)
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
import os
sys.path.append("..")
sys.path.append(".")
import torch
import torch.distributed as dist
import torch_harmonics as harmonics
from torch_harmonics.distributed.primitives import gather_from_parallel_region, scatter_to_parallel_region
try:
    from tqdm import tqdm
except ImportError:
    # fall back to a no-op wrapper if tqdm is not installed
    tqdm = lambda x : x
# set up distributed
world_size = int(os.getenv('WORLD_SIZE', 1))
world_rank = int(os.getenv('WORLD_RANK', 0))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
dist.init_process_group(backend = 'nccl',
                        init_method = f"tcp://{master_address}:{port}",
                        rank = world_rank,
                        world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
mp_group = dist.new_group(ranks=list(range(world_size)))
my_rank = dist.get_rank(mp_group)
group_size = 1 if not dist.is_initialized() else dist.get_world_size(mp_group)
device = torch.device(f"cuda:{local_rank}")
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
if my_rank == 0:
    print(f"Running distributed test on {group_size} ranks.")
# common parameters
b, c, n_theta, n_lambda = 1, 21, 361, 720
# do serial tests first:
#forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
# set up signal
with torch.no_grad():
    signal_leggauss = torch.randn(b, c, inverse_transform.lmax, inverse_transform.mmax, device=device, dtype=torch.complex128)
    signal_leggauss_dist = signal_leggauss.clone()
signal_leggauss.requires_grad = True
# do a fwd and bwd pass:
x_local = inverse_transform(signal_leggauss)
loss = torch.sum(x_local)
loss.backward()
local_grad = torch.view_as_real(signal_leggauss.grad.clone())
# now the distributed test
harmonics.distributed.init(mp_group)
inverse_transform_dist = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
with torch.no_grad():
    signal_leggauss_dist = scatter_to_parallel_region(signal_leggauss_dist, dim=2)
signal_leggauss_dist.requires_grad = True
# do distributed sht
x_dist = inverse_transform_dist(signal_leggauss_dist)
loss = torch.sum(x_dist)
loss.backward()
dist_grad = signal_leggauss_dist.grad.clone()
# gather the output
dist_grad = torch.view_as_real(gather_from_parallel_region(dist_grad, dim=2))
if my_rank == 0:
    print(f"Local Out: sum={x_local.abs().sum().item()}, max={x_local.max().item()}, min={x_local.min().item()}")
    print(f"Dist Out: sum={x_dist.abs().sum().item()}, max={x_dist.max().item()}, min={x_dist.min().item()}")
    diff = (x_local-x_dist).abs()
    print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(x_local.abs().sum() + x_dist.abs().sum()))}, max={diff.max().item()}")
    print("")
    print(f"Local Grad: sum={local_grad.abs().sum().item()}, max={local_grad.max().item()}, min={local_grad.min().item()}")
    print(f"Dist Grad: sum={dist_grad.abs().sum().item()}, max={dist_grad.max().item()}, min={dist_grad.min().item()}")
    diff = (local_grad-dist_grad).abs()
    print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(local_grad.abs().sum() + dist_grad.abs().sum()))}, max={diff.max().item()}")
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# ignore this (just for development without installation)
import sys
import os
sys.path.append("..")
sys.path.append(".")
import torch
import torch.distributed as dist
import torch_harmonics as harmonics
from torch_harmonics.distributed.primitives import gather_from_parallel_region
try:
    from tqdm import tqdm
except ImportError:
    # fall back to a no-op wrapper if tqdm is not installed
    tqdm = lambda x : x
# set up distributed
world_size = int(os.getenv('WORLD_SIZE', 1))
world_rank = int(os.getenv('WORLD_RANK', 0))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
dist.init_process_group(backend = 'nccl',
                        init_method = f"tcp://{master_address}:{port}",
                        rank = world_rank,
                        world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
mp_group = dist.new_group(ranks=list(range(world_size)))
my_rank = dist.get_rank(mp_group)
group_size = 1 if not dist.is_initialized() else dist.get_world_size(mp_group)
device = torch.device(f"cuda:{local_rank}")
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
if my_rank == 0:
    print(f"Running distributed test on {group_size} ranks.")
# common parameters
b, c, n_theta, n_lambda = 1, 21, 361, 720
# do serial tests first:
forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
# set up signal
with torch.no_grad():
    signal_leggauss = inverse_transform(torch.randn(b, c, forward_transform.lmax, forward_transform.mmax, device=device, dtype=torch.complex128))
    signal_leggauss_dist = signal_leggauss.clone()
signal_leggauss.requires_grad = True
signal_leggauss_dist.requires_grad = True
# do a fwd and bwd pass:
x_local = forward_transform(signal_leggauss)
loss = torch.sum(torch.view_as_real(x_local))
loss.backward()
x_local = torch.view_as_real(x_local)
local_grad = signal_leggauss.grad.clone()
# now the distributed test
harmonics.distributed.init(mp_group)
forward_transform_dist = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform_dist = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
# do distributed sht
x_dist = forward_transform_dist(signal_leggauss_dist)
loss = torch.sum(torch.view_as_real(x_dist))
loss.backward()
x_dist = torch.view_as_real(x_dist)
dist_grad = signal_leggauss_dist.grad.clone()
# gather the output
x_dist = gather_from_parallel_region(x_dist, dim=2)
if my_rank == 0:
    print(f"Local Out: sum={x_local.abs().sum().item()}, max={x_local.max().item()}, min={x_local.min().item()}")
    print(f"Dist Out: sum={x_dist.abs().sum().item()}, max={x_dist.max().item()}, min={x_dist.min().item()}")
    diff = (x_local-x_dist).abs()
    print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(x_local.abs().sum() + x_dist.abs().sum()))}, max={diff.max().item()}")
    print("")
    print(f"Local Grad: sum={local_grad.abs().sum().item()}, max={local_grad.max().item()}, min={local_grad.min().item()}")
    print(f"Dist Grad: sum={dist_grad.abs().sum().item()}, max={dist_grad.max().item()}, min={dist_grad.min().item()}")
    diff = (local_grad-dist_grad).abs()
    print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(local_grad.abs().sum() + dist_grad.abs().sum()))}, max={diff.max().item()}")
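
The test scripts above report a relative difference between the serial and distributed results, computed as the summed absolute difference divided by the mean of the two summed magnitudes. The following is a minimal sketch of that metric on plain Python sequences. The helper name `relative_difference` and the zero-scale guard are assumptions added for illustration and do not appear in the tests.

```python
def relative_difference(a, b):
    """Symmetric relative L1 difference: sum|a - b| / (0.5 * (sum|a| + sum|b|))."""
    abs_diff = sum(abs(x - y) for x, y in zip(a, b))
    scale = 0.5 * (sum(abs(x) for x in a) + sum(abs(y) for y in b))
    # guard against division by zero when both inputs vanish (an added assumption)
    return abs_diff / scale if scale > 0 else 0.0


if __name__ == "__main__":
    serial = [1.0, -2.0, 3.0]
    distributed = [1.0, -2.0, 3.0 + 1e-9]
    print("relative difference:", relative_difference(serial, distributed))
```

Normalizing by the mean magnitude of both inputs keeps the metric symmetric, so it does not matter which result is treated as the reference.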
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# we need this in order to enable distributed
from .utils import init, is_initialized
from .primitives import copy_to_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region