"vscode:/vscode.git/clone" did not exist on "8f8f8ef986913ee13c78bf0f066451a5ac62686a"
Commit ef4f6518 authored by Boris Bonev

v0.5 update

parent 77ac7836
......@@ -3,3 +3,4 @@ The code was authored by the following people:
Boris Bonev - NVIDIA Corporation
Christian Hundt - NVIDIA Corporation
Thorsten Kurth - NVIDIA Corporation
Nikola Kovachki - NVIDIA Corporation
......@@ -2,6 +2,11 @@
## Versioning
### v0.5
* Reworked distributed SHT
* Module for sampling Gaussian Random Fields on the sphere
### v0.4
* Computation of associated Legendre polynomials
......@@ -32,13 +37,4 @@
### v0.1
* Single GPU forward and backward transform
* Minimal code example and notebook
<!-- ## Detailed logs
### 23-11-2022
* Initialized the library
* Added `getting_started.ipynb` example
* Added simple example to test the SHT
* Logo -->
* Minimal code example and notebook
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
......@@ -36,7 +36,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<!-- ## What is torch-harmonics? -->
`torch_harmonics` is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It uses quadrature to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes.
Spherical Harmonic Transforms (SHTs) are the counterpart of Fourier transforms on the sphere. As such, they are an invaluable tool for signal processing on the sphere.
`torch_harmonics` is a differentiable implementation of the SHT in PyTorch. It uses quadrature rules to compute the projection onto the associated Legendre polynomials, and FFTs for the projection onto the harmonic basis. For most practical problem sizes, this approach outperforms algorithms with better asymptotic scaling.
`torch_harmonics` implements these operations with PyTorch primitives, making the transform fully differentiable. Moreover, the computation can be distributed across multiple ranks, yielding a spatially distributed SHT.
`torch_harmonics` has been used to implement a variety of differentiable PDE solvers, which generated the animations below.
<table border="0" cellspacing="0" cellpadding="0">
......@@ -76,6 +82,7 @@ docker run --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=671
- Boris Bonev (bbonev@nvidia.com)
- Christian Hundt (chundt@nvidia.com)
- Thorsten Kurth (tkurth@nvidia.com)
- Nikola Kovachki (nkovachki@nvidia.com)
## Implementation
The implementation follows the paper "Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations", N. Schaeffer, G3: Geochemistry, Geophysics, Geosystems.
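Schematically, the forward transform combines a real FFT in longitude with a quadrature rule in latitude (a sketch of the standard formulation, transcribed from the code rather than quoted from the paper):

$$\hat{f}_\ell^m = \sum_{j} w_j \, \bar{P}_\ell^m(\cos \theta_j) \sum_{k} f(\theta_j, \lambda_k) \, e^{-i m \lambda_k},$$

where $w_j$ are the quadrature weights, $\bar{P}_\ell^m$ the normalized associated Legendre functions, and the inner sum over the longitudes $\lambda_k$ is evaluated with an FFT.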
......@@ -126,7 +133,7 @@ The main functionality of `torch_harmonics` is provided in the form of `torch.nn
```python
import torch
import torch_harmonics as harmonics
import torch_harmonics as th
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
......@@ -136,11 +143,13 @@ batch_size = 32
signal = torch.randn(batch_size, nlat, nlon)
# transform data on an equiangular grid
sht = harmonics.RealSHT(nlat, nlon, grid="equiangular").to(device).float()
sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device).float()
coeffs = sht(signal)
```
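The inverse transform maps the coefficients back to the grid. A minimal continuation of the example above (`InverseRealSHT` is exported alongside `RealSHT`; for a non-bandlimited random signal the roundtrip is only approximate):

```python
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular").to(device).float()
reconstructed = isht(coeffs)  # back on the (nlat, nlon) grid
```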
`torch_harmonics` also implements a distributed variant of the SHT, located in `torch_harmonics.distributed`.
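A minimal sketch of the distributed usage, assuming `torch.distributed` is already initialized and that `pg_polar`/`pg_azimuth` are illustrative names for process groups covering the latitude and longitude splits (see the distributed tests for the full setup):

```python
import torch_harmonics.distributed as thd

thd.init(pg_polar, pg_azimuth)
dist_sht = thd.DistributedRealSHT(nlat=721, nlon=1440, grid="equiangular").to(device)
local_coeffs = dist_sht(local_signal)  # operates on this rank's padded shard
```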
## References
<a id="1">[1]</a>
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......@@ -333,7 +333,7 @@ class ShallowWaterSolver(nn.Module):
return out
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=True):
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
"""
Plotting routine for data on the grid. Requires cartopy for 3D plots.
"""
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......@@ -33,7 +33,7 @@ from setuptools import setup
setup(
name='torch_harmonics',
version='0.4',
version='0.5',
author='Boris Bonev',
author_email='bbonev@nvidia.com',
packages=['torch_harmonics',],
......
......@@ -29,86 +29,57 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
import os
sys.path.append("..")
sys.path.append(".")

# we need this in order to enable distributed
import torch
import torch.distributed as dist
import torch_harmonics as harmonics
from torch_harmonics.distributed.primitives import gather_from_parallel_region, scatter_to_parallel_region

try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x

# set up distributed
world_size = int(os.getenv('WORLD_SIZE', 1))
world_rank = int(os.getenv('WORLD_RANK', 0))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
dist.init_process_group(backend='nccl',
                        init_method=f"tcp://{master_address}:{port}",
                        rank=world_rank,
                        world_size=world_size)
local_rank = world_rank % torch.cuda.device_count()
mp_group = dist.new_group(ranks=list(range(world_size)))
my_rank = dist.get_rank(mp_group)
group_size = 1 if not dist.is_initialized() else dist.get_world_size(mp_group)
device = torch.device(f"cuda:{local_rank}")

# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)

if my_rank == 0:
    print(f"Running distributed test on {group_size} ranks.")

# common parameters
b, c, n_theta, n_lambda = 1, 21, 361, 720

# do serial tests first:
#forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)

# set up signal
with torch.no_grad():
    signal_leggauss = torch.randn(b, c, inverse_transform.lmax, inverse_transform.mmax, device=device, dtype=torch.complex128)
    signal_leggauss_dist = signal_leggauss.clone()
    signal_leggauss.requires_grad = True

# do a fwd and bwd pass:
x_local = inverse_transform(signal_leggauss)
loss = torch.sum(x_local)
loss.backward()
local_grad = torch.view_as_real(signal_leggauss.grad.clone())

# now the distributed test
harmonics.distributed.init(mp_group)
inverse_transform_dist = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
with torch.no_grad():
    signal_leggauss_dist = scatter_to_parallel_region(signal_leggauss_dist, dim=2)
    signal_leggauss_dist.requires_grad = True

# do distributed sht
x_dist = inverse_transform_dist(signal_leggauss_dist)
loss = torch.sum(x_dist)
loss.backward()
dist_grad = signal_leggauss_dist.grad.clone()

# gather the output
dist_grad = torch.view_as_real(gather_from_parallel_region(dist_grad, dim=2))

if my_rank == 0:
    print(f"Local Out: sum={x_local.abs().sum().item()}, max={x_local.max().item()}, min={x_local.min().item()}")
    print(f"Dist Out: sum={x_dist.abs().sum().item()}, max={x_dist.max().item()}, min={x_dist.min().item()}")
    diff = (x_local-x_dist).abs()
    print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(x_local.abs().sum() + x_dist.abs().sum()))}, max={diff.max().item()}")
    print("")
    print(f"Local Grad: sum={local_grad.abs().sum().item()}, max={local_grad.max().item()}, min={local_grad.min().item()}")
    print(f"Dist Grad: sum={dist_grad.abs().sum().item()}, max={dist_grad.max().item()}, min={dist_grad.min().item()}")
    diff = (local_grad-dist_grad).abs()
    print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(local_grad.abs().sum() + dist_grad.abs().sum()))}, max={diff.max().item()}")
......@@ -36,9 +36,10 @@ sys.path.append("..")
sys.path.append(".")
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
from torch_harmonics.distributed.primitives import gather_from_parallel_region
import torch_harmonics.distributed as thd
try:
from tqdm import tqdm
......@@ -46,70 +47,169 @@ except:
tqdm = lambda x : x
# set up distributed
world_size = int(os.getenv('WORLD_SIZE', 1))
world_rank = int(os.getenv('WORLD_RANK', 0))
grid_size_h = int(os.getenv('GRID_H', 1))
grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
world_size = grid_size_h * grid_size_w
dist.init_process_group(backend = 'nccl',
init_method = f"tcp://{master_address}:{port}",
rank = world_rank,
world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
mp_group = dist.new_group(ranks=list(range(world_size)))
my_rank = dist.get_rank(mp_group)
group_size = 1 if not dist.is_initialized() else dist.get_world_size(mp_group)
device = torch.device(f"cuda:{local_rank}")
# compute local ranks in h and w:
# rank = wrank + grid_size_w * hrank
wrank = world_rank % grid_size_w
hrank = world_rank // grid_size_w
w_group = None
h_group = None
# now set up the comm grid:
wgroups = []
for h in range(grid_size_h):
    start = h * grid_size_w
    end = start + grid_size_w
    wgroups.append(list(range(start, end)))
print(wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if world_rank in grp:
w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
print(hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if world_rank in grp:
h_group = tmp_group
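# Worked example (added illustration): with GRID_H=2, GRID_W=3 and
# rank = wrank + grid_size_w * hrank, the groups come out as
# wgroups = [[0, 1, 2], [3, 4, 5]] (rows, ranks sharing an hrank)
# hgroups = [[0, 3], [1, 4], [2, 5]] (columns, ranks sharing a wrank)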
# set device
torch.cuda.set_device(device.index)
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
if my_rank == 0:
print(f"Running distributed test on {group_size} ranks.")
if world_rank == 0:
print(f"Running distributed test on grid H x W = {grid_size_h} x {grid_size_w}")
# initializing sht
thd.init(h_group, w_group)
# common parameters
b, c, n_theta, n_lambda = 1, 21, 361, 720
B, C, H, W = 1, 8, 721, 1440
Hloc = (H + grid_size_h - 1) // grid_size_h
Wloc = (W + grid_size_w - 1) // grid_size_w
Hpad = grid_size_h * Hloc - H
Wpad = grid_size_w * Wloc - W
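# Example (added illustration): on a 2 x 2 grid with H, W = 721, 1440,
# Hloc = (721 + 1) // 2 = 361 and Hpad = 2 * 361 - 721 = 1 (one padded row),
# while Wloc = 1440 // 2 = 720 and Wpad = 0.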
# do serial tests first:
forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W).to(device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W).to(device)
Lloc = (forward_transform_dist.lpad + forward_transform_dist.lmax) // grid_size_h
Mloc = (forward_transform_dist.mpad + forward_transform_dist.mmax) // grid_size_w
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device)
# set up signal
# pad
with torch.no_grad():
signal_leggauss = inverse_transform(torch.randn(b, c, forward_transform.lmax, forward_transform.mmax, device=device, dtype=torch.complex128))
signal_leggauss_dist = signal_leggauss.clone()
signal_leggauss.requires_grad = True
signal_leggauss_dist.requires_grad = True
# do a fwd and bwd pass:
x_local = forward_transform(signal_leggauss)
loss = torch.sum(torch.view_as_real(x_local))
loss.backward()
x_local = torch.view_as_real(x_local)
local_grad = signal_leggauss.grad.clone()
# now the distributed test
harmonics.distributed.init(mp_group)
forward_transform_dist = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform_dist = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
# do distributed sht
x_dist = forward_transform_dist(signal_leggauss_dist)
loss = torch.sum(torch.view_as_real(x_dist))
loss.backward()
x_dist = torch.view_as_real(x_dist)
dist_grad = signal_leggauss_dist.grad.clone()
# gather the output
x_dist = gather_from_parallel_region(x_dist, dim=2)
if my_rank == 0:
print(f"Local Out: sum={x_local.abs().sum().item()}, max={x_local.max().item()}, min={x_local.min().item()}")
print(f"Dist Out: sum={x_dist.abs().sum().item()}, max={x_dist.max().item()}, min={x_dist.min().item()}")
diff = (x_local-x_dist).abs()
print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(x_local.abs().sum() + x_dist.abs().sum()))}, max={diff.max().item()}")
inp_pad = F.pad(inp_full, (0, Wpad, 0, Hpad))
# split in W
inp_local = torch.split(inp_pad, split_size_or_sections=Wloc, dim=-1)[wrank]
# split in H
inp_local = torch.split(inp_local, split_size_or_sections=Hloc, dim=-2)[hrank]
# do FWD transform
out_full = forward_transform_local(inp_full)
out_local = forward_transform_dist(inp_local)
# gather the local data
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(out_local) for _ in range(grid_size_w)]
olist[wrank] = out_local
dist.all_gather(olist, out_local, group=w_group)
out_full_gather = torch.cat(olist, dim=-1)
out_full_gather = out_full_gather[..., :forward_transform_dist.mmax]
else:
out_full_gather = out_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(out_full_gather) for _ in range(grid_size_h)]
olist[hrank] = out_full_gather
dist.all_gather(olist, out_full_gather, group=h_group)
out_full_gather = torch.cat(olist, dim=-2)
out_full_gather = out_full_gather[..., :forward_transform_dist.lmax, :]
if world_rank == 0:
print(f"Local Out: sum={out_full.abs().sum().item()}, max={out_full.abs().max().item()}, min={out_full.abs().min().item()}")
print(f"Dist Out: sum={out_full_gather.abs().sum().item()}, max={out_full_gather.abs().max().item()}, min={out_full_gather.abs().min().item()}")
diff = (out_full-out_full_gather).abs()
print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(out_full.abs().sum() + out_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
print("")
print(f"Local Grad: sum={local_grad.abs().sum().item()}, max={local_grad.max().item()}, min={local_grad.min().item()}")
print(f"Dist Grad: sum={dist_grad.abs().sum().item()}, max={dist_grad.max().item()}, min={dist_grad.min().item()}")
diff = (local_grad-dist_grad).abs()
print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(local_grad.abs().sum() + dist_grad.abs().sum()))}, max={diff.max().item()}")
# create split input grad
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# pad
ograd_pad = F.pad(ograd_full, [0, forward_transform_dist.mpad, 0, forward_transform_dist.lpad])
# split in M
ograd_local = torch.split(ograd_pad, split_size_or_sections=Mloc, dim=-1)[wrank]
# split in H
ograd_local = torch.split(ograd_local, split_size_or_sections=Lloc, dim=-2)[hrank]
# backward pass:
# local
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
# distributed
inp_local.requires_grad = True
out_local = forward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
# gather
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(igrad_local) for _ in range(grid_size_w)]
olist[wrank] = igrad_local
dist.all_gather(olist, igrad_local, group=w_group)
igrad_full_gather = torch.cat(olist, dim=-1)
igrad_full_gather = igrad_full_gather[..., :W]
else:
igrad_full_gather = igrad_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(igrad_full_gather) for _ in range(grid_size_h)]
olist[hrank] = igrad_full_gather
dist.all_gather(olist, igrad_full_gather, group=h_group)
igrad_full_gather = torch.cat(olist, dim=-2)
igrad_full_gather = igrad_full_gather[..., :H, :]
if world_rank == 0:
print(f"Local Grad: sum={igrad_full.abs().sum().item()}, max={igrad_full.abs().max().item()}, min={igrad_full.abs().min().item()}")
print(f"Dist Grad: sum={igrad_full_gather.abs().sum().item()}, max={igrad_full_gather.abs().max().item()}, min={igrad_full_gather.abs().min().item()}")
diff = (igrad_full-igrad_full_gather).abs()
print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(igrad_full.abs().sum() + igrad_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
......@@ -31,3 +31,4 @@
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature
from . import random_fields
......@@ -30,5 +30,10 @@
#
# we need this in order to enable distributed
from .utils import init, is_initialized
from .primitives import copy_to_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
from .utils import init, is_initialized, polar_group, azimuth_group
from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank
from .primitives import distributed_transpose_azimuth, distributed_transpose_polar
# import the sht stuff
from .distributed_sht import DistributedRealSHT, DistributedInverseRealSHT
from .distributed_sht import DistributedRealVectorSHT, DistributedInverseRealVectorSHT
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import numpy as np
import torch
import torch.nn as nn
import torch.fft
import torch.nn.functional as F
from torch_harmonics.quadrature import *
from torch_harmonics.legendre import *
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
class DistributedRealSHT(nn.Module):
"""
Defines a module for computing the forward (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last two dimensions of the input
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
"""
Initializes the SHT Layer, precomputing the necessary quadrature weights
Parameters:
nlat: input grid resolution in the latitudinal direction
nlon: input grid resolution in the longitudinal direction
grid: grid in the latitude direction (for now only tensor product grids are supported)
"""
super().__init__()
self.nlat = nlat
self.nlon = nlon
self.grid = grid
self.norm = norm
self.csphase = csphase
# TODO: include assertions regarding the dimensions
# compute quadrature points
if self.grid == "legendre-gauss":
cost, w = legendre_gauss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
# cost, w = fejer2_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
raise(ValueError("Unknown quadrature mode"))
# get the comms grid:
self.comm_size_polar = polar_group_size()
self.comm_rank_polar = polar_group_rank()
self.comm_size_azimuth = azimuth_group_size()
self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# spatial paddings
latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
self.nlatpad = latdist * self.comm_size_polar - self.nlat
londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.nlonpad = londist * self.comm_size_azimuth - self.nlon
# frequency paddings
ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
self.lpad = ldist * self.comm_size_polar - self.lmax
mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.mpad = mdist * self.comm_size_azimuth - self.mmax
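# Note (added illustration): the ceil-divide sharding above means that, e.g.,
# nlat=721 with comm_size_polar=2 gives latdist=361 and nlatpad=1: every rank
# holds 361 latitude rows, with one padded row on the last rank; the same
# logic is applied to nlon, lmax and mmax.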
# combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
weights = torch.einsum('mlk,k->mlk', pct, weights)
# we need to split in m, pad before:
weights = F.pad(weights, [0, 0, 0, 0, 0, self.mpad], mode="constant")
weights = torch.split(weights, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]
# compute the local pad and size
# spatial
self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
self.nlatpad_local = latdist - self.nlat_local
self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
self.nlonpad_local = londist - self.nlon_local
# frequency
self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
self.lpad_local = ldist - self.lmax_local
self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
self.mpad_local = mdist - self.mmax_local
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor):
# we need to ensure that we can split the channels evenly
assert(x.shape[1] % self.comm_size_polar == 0)
assert(x.shape[1] % self.comm_size_azimuth == 0)
# h and w is split. First we make w local by transposing into channel dim
if self.comm_size_azimuth > 1:
xt = distributed_transpose_azimuth.apply(x, (1, -1))
else:
xt = x
# apply real fft in the longitudinal direction: make sure to truncate to nlon
xtf = 2.0 * torch.pi * torch.fft.rfft(xt, n=self.nlon, dim=-1, norm="forward")
# truncate
xtft = xtf[..., :self.mmax]
# pad the dim to allow for splitting
xtfp = F.pad(xtft, [0, self.mpad], mode="constant")
# transpose: after this, m is split and c is local
if self.comm_size_azimuth > 1:
y = distributed_transpose_azimuth.apply(xtfp, (-1, 1))
else:
y = xtfp
# transpose: after this, c is split and h is local
if self.comm_size_polar > 1:
yt = distributed_transpose_polar.apply(y, (1, -2))
else:
yt = y
# the input data might be padded, make sure to truncate to nlat:
ytt = yt[..., :self.nlat, :]
# do the Legendre-Gauss quadrature
yttr = torch.view_as_real(ytt)
# contraction
yor = torch.einsum('...kmr,mlk->...lmr', yttr, self.weights.to(yttr.dtype)).contiguous()
# pad if required, truncation is implicit
yopr = F.pad(yor, [0, 0, 0, 0, 0, self.lpad], mode="constant")
yop = torch.view_as_complex(yopr)
# transpose: after this, l is split and c is local
if self.comm_size_polar > 1:
y = distributed_transpose_polar.apply(yop, (-2, 1))
else:
y = yop
return y
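# Usage sketch (added illustration, hypothetical sizes): without a prior call
# to the distributed init, the polar/azimuth group sizes default to 1 and the
# module reduces to a single-device SHT:
#
# sht = DistributedRealSHT(nlat=181, nlon=360, grid="equiangular")
# x = torch.randn(1, 8, 181, 360)
# coeffs = sht(x)  # complex coefficients of shape (1, 8, lmax, mmax)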
class DistributedInverseRealSHT(nn.Module):
"""
Defines a module for computing the inverse (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
nlat, nlon: Output dimensions
lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
super().__init__()
self.nlat = nlat
self.nlon = nlon
self.grid = grid
self.norm = norm
self.csphase = csphase
# compute quadrature points
if self.grid == "legendre-gauss":
cost, _ = legendre_gauss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, _ = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
elif self.grid == "equiangular":
cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
raise(ValueError("Unknown quadrature mode"))
# get the comms grid:
self.comm_size_polar = polar_group_size()
self.comm_rank_polar = polar_group_rank()
self.comm_size_azimuth = azimuth_group_size()
self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# spatial paddings
latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
self.nlatpad = latdist * self.comm_size_polar - self.nlat
londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.nlonpad = londist * self.comm_size_azimuth - self.nlon
# frequency paddings
ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
self.lpad = ldist * self.comm_size_polar - self.lmax
mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute legendre polynomials
pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# split in m
pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
pct = torch.split(pct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]
# compute the local pads and sizes
# spatial
self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
self.nlatpad_local = latdist - self.nlat_local
self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
self.nlonpad_local = londist - self.nlon_local
# frequency
self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
self.lpad_local = ldist - self.lmax_local
self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
self.mpad_local = mdist - self.mmax_local
# register
self.register_buffer('pct', pct, persistent=False)
def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor):
# we need to ensure that we can split the channels evenly
assert(x.shape[1] % self.comm_size_polar == 0)
assert(x.shape[1] % self.comm_size_azimuth == 0)
# transpose: after that, channels are split, l is local:
if self.comm_size_polar > 1:
xt = distributed_transpose_polar.apply(x, (1, -2))
else:
xt = x
# remove padding in l:
xtt = xt[..., :self.lmax, :]
# Evaluate associated Legendre functions on the output nodes
xttr = torch.view_as_real(xtt)
# einsum
xs = torch.einsum('...lmr, mlk->...kmr', xttr, self.pct.to(xttr.dtype)).contiguous()
x = torch.view_as_complex(xs)
# transpose: after this, l is split and channels are local
xp = F.pad(x, [0, 0, 0, self.nlatpad])
if self.comm_size_polar > 1:
y = distributed_transpose_polar.apply(xp, (-2, 1))
else:
y = xp
# transpose: after this, channels are split and m is local
if self.comm_size_azimuth > 1:
yt = distributed_transpose_azimuth.apply(y, (1, -1))
else:
yt = y
# truncate
ytt = yt[..., :self.mmax]
# apply the inverse (real) FFT
x = torch.fft.irfft(ytt, n=self.nlon, dim=-1, norm="forward")
# pad before we transpose back
xp = F.pad(x, [0, self.nlonpad])
# transpose: after this, m is split and channels are local
if self.comm_size_azimuth > 1:
out = distributed_transpose_azimuth.apply(xp, (-1, 1))
else:
out = xp
return out
class DistributedRealVectorSHT(nn.Module):
"""
Defines a module for computing the forward (real) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last three dimensions of the input.
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
"""
Initializes the vector SHT Layer, precomputing the necessary quadrature weights
Parameters:
nlat: input grid resolution in the latitudinal direction
nlon: input grid resolution in the longitudinal direction
grid: type of grid the data lives on
"""
super().__init__()
self.nlat = nlat
self.nlon = nlon
self.grid = grid
self.norm = norm
self.csphase = csphase
# compute quadrature points
if self.grid == "legendre-gauss":
cost, w = legendre_gauss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, w = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
elif self.grid == "equiangular":
cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
# cost, w = fejer2_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
raise(ValueError("Unknown quadrature mode"))
# get the comms grid:
self.comm_size_polar = polar_group_size()
self.comm_rank_polar = polar_group_rank()
self.comm_size_azimuth = azimuth_group_size()
self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them
tq = np.flip(np.arccos(cost))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# spatial paddings
latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
self.nlatpad = latdist * self.comm_size_polar - self.nlat
londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.nlonpad = londist * self.comm_size_azimuth - self.nlon
# frequency paddings
ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
self.lpad = ldist * self.comm_size_polar - self.lmax
mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.mpad = mdist * self.comm_size_azimuth - self.mmax
weights = torch.from_numpy(w)
dpct = precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
# combine integration weights, normalization factor in to one:
l = torch.arange(0, self.lmax)
norm_factor = 1. / l / (l+1)
norm_factor[0] = 1.
weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
# since the second component is imaginary, we need to take complex conjugation into account
weights[1] = -1 * weights[1]
# we need to split in m, pad before:
weights = F.pad(weights, [0, 0, 0, 0, 0, self.mpad], mode="constant")
weights = torch.split(weights, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=1)[self.comm_rank_azimuth]
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
# compute the local pad and size
# spatial
self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
self.nlatpad_local = latdist - self.nlat_local
self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
self.nlonpad_local = londist - self.nlon_local
# frequency
self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
self.lpad_local = ldist - self.lmax_local
self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
self.mpad_local = mdist - self.mmax_local
def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor):
assert(len(x.shape) >= 3)
assert(x.shape[1] % self.comm_size_polar == 0)
assert(x.shape[1] % self.comm_size_azimuth == 0)
# h and w is split. First we make w local by transposing into channel dim
if self.comm_size_azimuth > 1:
xt = distributed_transpose_azimuth.apply(x, (1, -1))
else:
xt = x
# apply real fft in the longitudinal direction: make sure to truncate to nlon
xtf = 2.0 * torch.pi * torch.fft.rfft(xt, n=self.nlon, dim=-1, norm="forward")
# truncate
xtft = xtf[..., :self.mmax]
# pad the dim to allow for splitting
xtfp = F.pad(xtft, [0, self.mpad], mode="constant")
# transpose: after this, m is split and c is local
if self.comm_size_azimuth > 1:
y = distributed_transpose_azimuth.apply(xtfp, (-1, 1))
else:
y = xtfp
# transpose: after this, c is split and h is local
if self.comm_size_polar > 1:
yt = distributed_transpose_polar.apply(y, (1, -2))
else:
yt = y
# the input data might be padded, make sure to truncate to nlat:
ytt = yt[..., :self.nlat, :]
# do the Legendre-Gauss quadrature
yttr = torch.view_as_real(ytt)
# create output array with the transformed (l) dimension
out_shape = list(yttr.shape)
out_shape[-3] = self.lmax
yor = torch.zeros(out_shape, dtype=yttr.dtype, device=yttr.device)
# contraction - spheroidal component
# real component
yor[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 0], self.weights[0].to(yttr.dtype)) \
- torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 1], self.weights[1].to(yttr.dtype))
# imag component
yor[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 1], self.weights[0].to(yttr.dtype)) \
+ torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 0], self.weights[1].to(yttr.dtype))
# contraction - toroidal component
# real component
yor[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 1], self.weights[1].to(yttr.dtype)) \
- torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 0], self.weights[0].to(yttr.dtype))
# imag component
yor[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 0], self.weights[1].to(yttr.dtype)) \
- torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 1], self.weights[0].to(yttr.dtype))
# pad if required
yopr = F.pad(yor, [0, 0, 0, 0, 0, self.lpad], mode="constant")
yop = torch.view_as_complex(yopr)
# transpose: after this, l is split and c is local
if self.comm_size_polar > 1:
y = distributed_transpose_polar.apply(yop, (-2, 1))
else:
y = yop
return y
class DistributedInverseRealVectorSHT(nn.Module):
"""
Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
super().__init__()
self.nlat = nlat
self.nlon = nlon
self.grid = grid
self.norm = norm
self.csphase = csphase
# compute quadrature points
if self.grid == "legendre-gauss":
cost, _ = legendre_gauss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
elif self.grid == "lobatto":
cost, _ = lobatto_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat-1
elif self.grid == "equiangular":
cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
self.lmax = lmax or self.nlat
else:
raise(ValueError("Unknown quadrature mode"))
self.comm_size_polar = polar_group_size()
self.comm_rank_polar = polar_group_rank()
self.comm_size_azimuth = azimuth_group_size()
self.comm_rank_azimuth = azimuth_group_rank()
# apply cosine transform and flip them
t = np.flip(np.arccos(cost))
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# spatial paddings
latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
self.nlatpad = latdist * self.comm_size_polar - self.nlat
londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.nlonpad = londist * self.comm_size_azimuth - self.nlon
# frequency paddings
ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
self.lpad = ldist * self.comm_size_polar - self.lmax
mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute legendre polynomials
dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# split in m
dpct = F.pad(dpct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
dpct = torch.split(dpct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=1)[self.comm_rank_azimuth]
# register buffer
self.register_buffer('dpct', dpct, persistent=False)
# compute the local pad and size
# spatial
self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
self.nlatpad_local = latdist - self.nlat_local
self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
self.nlonpad_local = londist - self.nlon_local
# frequency
self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
self.lpad_local = ldist - self.lmax_local
self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
self.mpad_local = mdist - self.mmax_local
def extra_repr(self):
"""
Pretty print module
"""
return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'
def forward(self, x: torch.Tensor):
assert(x.shape[1] % self.comm_size_polar == 0)
assert(x.shape[1] % self.comm_size_azimuth == 0)
# transpose: after that, channels are split, l is local:
if self.comm_size_polar > 1:
xt = distributed_transpose_polar.apply(x, (1, -2))
else:
xt = x
# remove padding in l:
xtt = xt[..., :self.lmax, :]
# Evaluate associated Legendre functions on the output nodes
xttr = torch.view_as_real(xtt)
# contraction - spheroidal component
# real component
srl = torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 0], self.dpct[0].to(xttr.dtype)) \
- torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 1], self.dpct[1].to(xttr.dtype))
# imag component
sim = torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 1], self.dpct[0].to(xttr.dtype)) \
+ torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 0], self.dpct[1].to(xttr.dtype))
# contraction - toroidal component
# real component
trl = - torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 1], self.dpct[1].to(xttr.dtype)) \
- torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 0], self.dpct[0].to(xttr.dtype))
# imag component
tim = torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 0], self.dpct[1].to(xttr.dtype)) \
- torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 1], self.dpct[0].to(xttr.dtype))
# reassemble
s = torch.stack((srl, sim), -1)
t = torch.stack((trl, tim), -1)
xs = torch.stack((s, t), -4)
# convert to complex
x = torch.view_as_complex(xs)
# transpose: after this, l is split and channels are local
xp = F.pad(x, [0, 0, 0, self.nlatpad])
if self.comm_size_polar > 1:
y = distributed_transpose_polar.apply(xp, (-2, 1))
else:
y = xp
# transpose: after this, channels are split and m is local
if self.comm_size_azimuth > 1:
yt = distributed_transpose_azimuth.apply(y, (1, -1))
else:
yt = y
# truncate
ytt = yt[..., :self.mmax]
# apply the inverse (real) FFT
x = torch.fft.irfft(ytt, n=self.nlon, dim=-1, norm="forward")
# pad before we transpose back
xp = F.pad(x, [0, self.nlonpad])
# transpose: after this, m is split and channels are local
if self.comm_size_azimuth > 1:
out = distributed_transpose_azimuth.apply(xp, (-1, 1))
else:
out = xp
return out
\ No newline at end of file
......@@ -32,7 +32,7 @@
import torch
import torch.distributed as dist
from .utils import get_model_parallel_group, is_initialized
from .utils import polar_group, azimuth_group, is_initialized
# general helpers
def get_memory_format(tensor):
......@@ -50,163 +50,54 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
return tensor_list
# split
def _split(input_, dim_, group=None):
"""Split the tensor along its last dimension and keep the corresponding slice."""
def _transpose(tensor, dim0, dim1, group=None, async_op=False):
# get input format
input_format = get_memory_format(input_)
input_format = get_memory_format(tensor)
# Bypass the function if we are using only 1 GPU.
# get comm params
comm_size = dist.get_world_size(group=group)
if comm_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_dim(input_, dim_, comm_size)
# split and local transposition
split_size = tensor.shape[dim0] // comm_size
x_send = [y.contiguous(memory_format=input_format) for y in torch.split(tensor, split_size, dim=dim0)]
x_recv = [torch.empty_like(x_send[0]).contiguous(memory_format=input_format) for _ in range(comm_size)]
# Note: torch.split does not create contiguous tensors by default.
rank = dist.get_rank(group=group)
output = input_list[rank].contiguous(memory_format=input_format)
# global transposition
req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
return output
return x_recv, req
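# Note (added for clarity): _transpose implements one leg of a distributed
# transpose: each rank splits its tensor into comm_size chunks along dim0 and
# exchanges them via all_to_all; the caller concatenates the received chunks
# along dim1. Dimension dim0, previously local, ends up split across the
# group, while dim1, previously split, is assembled locally.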
# those are used by the various helper functions
def _reduce(input_, use_fp32=True, group=None):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# All-reduce.
if use_fp32:
dtype = input_.dtype
inputf_ = input_.float()
dist.all_reduce(inputf_, group=group)
input_ = inputf_.to(dtype)
else:
dist.all_reduce(input_, group=group)
return input_
class _CopyToParallelRegion(torch.autograd.Function):
"""Pass the input to the parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
class distributed_transpose_azimuth(torch.autograd.Function):
@staticmethod
def forward(ctx, input_):
return input_
def forward(ctx, x, dim):
xlist, _ = _transpose(x, dim[0], dim[1], group=azimuth_group())
x = torch.cat(xlist, dim=dim[1])
ctx.dim = dim
return x
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, group=get_model_parallel_group())
# write a convenient functional wrapper
def copy_to_parallel_region(input_):
if not is_initialized():
return input_
else:
return _CopyToParallelRegion.apply(input_)
def backward(ctx, go):
dim = ctx.dim
gilist, _ = _transpose(go, dim[1], dim[0], group=azimuth_group())
gi = torch.cat(gilist, dim=dim[0])
return gi, None
# reduce
class _ReduceFromParallelRegion(torch.autograd.Function):
"""All-reduce the input from the parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_, group=get_model_parallel_group())
@staticmethod
def forward(ctx, input_):
return _reduce(input_, group=get_model_parallel_group())
class distributed_transpose_polar(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
return grad_output
def reduce_from_parallel_region(input_):
if not is_initialized():
return input_
else:
return _ReduceFromParallelRegion.apply(input_)
# gather
def _gather(input_, dim_, group=None):
"""Gather tensors and concatinate along the last dimension."""
# get input format
input_format = get_memory_format(input_)
print(input_format)
comm_size = dist.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if comm_size==1:
return input_
# sanity checks
assert(dim_ < input_.dim()), f"Error, cannot gather along {dim_} for tensor with {input_.dim()} dimensions."
# Size and dimension.
comm_rank = dist.get_rank(group=group)
def forward(ctx, x, dim):
xlist, _ = _transpose(x, dim[0], dim[1], group=polar_group())
x = torch.cat(xlist, dim=dim[1])
ctx.dim = dim
return x
# input needs to be contiguous
input_ = input_.contiguous(memory_format=input_format)
tensor_list = [torch.empty_like(input_) for _ in range(comm_size)]
tensor_list[comm_rank] = input_
dist.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format)
return output
class _GatherFromParallelRegion(torch.autograd.Function):
"""Gather the input from parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, dim_):
return _gather(input_, dim_, group=get_model_parallel_group())
@staticmethod
def forward(ctx, input_, dim_):
ctx.dim = dim_
return _gather(input_, dim_, group=get_model_parallel_group())
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, group=get_model_parallel_group()), None
def gather_from_parallel_region(input_, dim):
if not is_initialized():
return input_
else:
return _GatherFromParallelRegion.apply(input_, dim)
def backward(ctx, go):
dim = ctx.dim
gilist, _ = _transpose(go, dim[1], dim[0], group=polar_group())
gi = torch.cat(gilist, dim=dim[0])
return gi, None
# scatter
class _ScatterToParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_, dim_):
return _split(input_, dim_, group=get_model_parallel_group())
@staticmethod
def forward(ctx, input_, dim_):
ctx.dim = dim_
return _split(input_, dim_, group=get_model_parallel_group())
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.dim, group=get_model_parallel_group()), None
def scatter_to_parallel_region(input_, dim):
if not is_initialized():
return input_
else:
return _ScatterToParallelRegion.apply(input_, dim)
......@@ -34,15 +34,52 @@ import torch
import torch.distributed as dist
# those need to be global
_MODEL_PARALLEL_GROUP = None
_POLAR_PARALLEL_GROUP = None
_AZIMUTH_PARALLEL_GROUP = None
_IS_INITIALIZED = False
def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP
def polar_group():
return _POLAR_PARALLEL_GROUP
def init(process_group):
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = process_group
def azimuth_group():
return _AZIMUTH_PARALLEL_GROUP
def init(polar_process_group, azimuth_process_group):
global _POLAR_PARALLEL_GROUP
global _AZIMUTH_PARALLEL_GROUP
_POLAR_PARALLEL_GROUP = polar_process_group
_AZIMUTH_PARALLEL_GROUP = azimuth_process_group
_IS_INITIALIZED = True
def is_initialized() -> bool:
return _MODEL_PARALLEL_GROUP is not None
return _IS_INITIALIZED
def is_distributed_polar() -> bool:
return (_POLAR_PARALLEL_GROUP is not None)
def is_distributed_azimuth() -> bool:
return (_AZIMUTH_PARALLEL_GROUP is not None)
def polar_group_size() -> int:
if not is_distributed_polar():
return 1
else:
return dist.get_world_size(group = _POLAR_PARALLEL_GROUP)
def azimuth_group_size() -> int:
if not is_distributed_azimuth():
return 1
else:
return dist.get_world_size(group = _AZIMUTH_PARALLEL_GROUP)
def polar_group_rank() -> int:
if not is_distributed_polar():
return 0
else:
return dist.get_rank(group = _POLAR_PARALLEL_GROUP)
def azimuth_group_rank() -> int:
if not is_distributed_azimuth():
return 0
else:
return dist.get_rank(group = _AZIMUTH_PARALLEL_GROUP)
......@@ -53,7 +53,8 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True)
"""
# compute the tensor P^m_n:
pct = np.zeros((mmax, lmax, len(t)), dtype=np.float64)
nmax = max(mmax,lmax)
pct = np.zeros((nmax, nmax, len(t)), dtype=np.float64)
sint = np.sin(t)
cost = np.cos(t)
......@@ -65,23 +66,25 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True)
pct[0,0,:] = norm_factor / np.sqrt(4 * np.pi)
# fill the diagonal and the lower diagonal
for l in range(1, min(mmax,lmax)):
for l in range(1, nmax):
pct[l-1, l, :] = np.sqrt(2*l + 1) * cost * pct[l-1, l-1, :]
pct[l, l, :] = np.sqrt( (2*l + 1) * (1 + cost) * (1 - cost) / 2 / l ) * pct[l-1, l-1, :]
# fill the remaining values on the upper triangle and multiply b
for l in range(2, lmax):
for l in range(2, nmax):
for m in range(0, l-1):
pct[m, l, :] = cost * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * pct[m, l-1, :] \
- np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * pct[m, l-2, :]
if norm == "schmidt":
for l in range(0, lmax):
for l in range(0, nmax):
if inverse:
pct[:, l, : ] = pct[:, l, : ] * np.sqrt(2*l + 1)
else:
pct[:, l, : ] = pct[:, l, : ] / np.sqrt(2*l + 1)
pct = pct[:mmax, :lmax]
if csphase:
for m in range(1, mmax, 2):
pct[m] *= -1
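# Note (added for clarity): the loops above implement the standard three-term
# recurrence for the normalized associated Legendre functions,
# P^m_l(x) = sqrt((2l-1)/(l-m) * (2l+1)/(l+m)) * x * P^m_{l-1}(x)
#          - sqrt((l+m-1)/(l-m) * (2l+1)/(2l-3) * (l-m-1)/(l+m)) * P^m_{l-2}(x),
# seeded by the diagonal recursion for P^l_l and the lower diagonal P^{l-1}_l.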
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
from .sht import InverseRealSHT
class GaussianRandomFieldS2(torch.nn.Module):
def __init__(self, nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid="equiangular", dtype=torch.float32):
super().__init__()
"""A mean-zero Gaussian Random Field on the sphere with Matern covariance:
C = sigma^2 (-Lap + tau^2 I)^(-alpha).
Lap is the Laplacian on the sphere, I the identity operator,
and sigma, tau, alpha are scalar parameters.
Note: C is trace-class on L^2 if and only if alpha > 1.
Parameters
----------
nlat : int
Number of latitudinal modes;
longitudinal modes are 2*nlat.
alpha : float, default is 2
Regularity parameter. Larger means smoother.
tau : float, default is 3
Length-scale parameter. Larger means more scales.
sigma : float, default is None
Scale parameter. Larger means bigger.
If None, sigma = tau**(0.5*(2*alpha - 2.0)).
radius : float, default is 1
Radius of the sphere.
grid : string, default is "equiangular"
Grid type. Currently supports "equiangular" and
"legendre-gauss".
dtype : torch.dtype, default is torch.float32
Numerical type for the calculations.
"""
#Number of latitudinal modes.
self.nlat = nlat
#Default value of sigma if None is given.
if sigma is None:
assert alpha > 1.0, f"Alpha must be greater than one, got {alpha}."
sigma = tau**(0.5*(2*alpha - 2.0))
# Inverse SHT
self.isht = InverseRealSHT(self.nlat, 2*self.nlat, grid=grid, norm='backward').to(dtype=dtype)
#Square root of the eigenvalues of C.
sqrt_eig = torch.tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat,1).repeat(1, self.nlat+1)
sqrt_eig = torch.tril(sigma*(((sqrt_eig/radius**2) + tau**2)**(-alpha/2.0)))
sqrt_eig[0,0] = 0.0
sqrt_eig = sqrt_eig.unsqueeze(0)
self.register_buffer('sqrt_eig', sqrt_eig)
#Save mean and var of the standard Gaussian.
#Need these to re-initialize distribution on a new device.
mean = torch.tensor([0.0]).to(dtype=dtype)
var = torch.tensor([1.0]).to(dtype=dtype)
self.register_buffer('mean', mean)
self.register_buffer('var', var)
#Standard normal noise sampler.
self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
def forward(self, N, xi=None):
"""Sample random functions from a spherical GRF.
Parameters
----------
N : int
Number of functions to sample.
xi : torch.Tensor, default is None
Noise is a complex tensor of size (N, nlat, nlat+1).
If None, new Gaussian noise is sampled.
If xi is provided, N is ignored.
Output
-------
u : torch.Tensor
N random samples from the GRF returned as a
tensor of size (N, nlat, 2*nlat) on an equiangular grid.
"""
#Sample Gaussian noise.
if xi is None:
xi = self.gaussian_noise.sample(torch.Size((N, self.nlat, self.nlat + 1, 2))).squeeze()
xi = torch.view_as_complex(xi)
#Karhunen-Loeve expansion.
u = self.isht(xi*self.sqrt_eig)
return u
#Override cuda and to methods so sampler gets initialized with mean
#and variance on the correct device.
def cuda(self, *args, **kwargs):
super().cuda(*args, **kwargs)
self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
return self
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
return self
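# Usage sketch (added illustration, hypothetical sizes):
# grf = GaussianRandomFieldS2(nlat=128, alpha=2.0, tau=3.0)
# u = grf(4)  # 4 samples on the 128 x 256 equiangular grid, shape (4, 128, 256)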
......@@ -36,7 +36,6 @@ import torch.fft
from .quadrature import *
from .legendre import *
from .distributed import copy_to_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
class RealSHT(nn.Module):
......@@ -94,10 +93,6 @@ class RealSHT(nn.Module):
pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
weights = torch.einsum('mlk,k->mlk', pct, weights)
# shard the weights along n, because we want to be distributed in spectral space:
weights = scatter_to_parallel_region(weights, dim=1)
self.lmax_local = weights.shape[1]
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
......@@ -119,9 +114,8 @@ class RealSHT(nn.Module):
x = torch.view_as_real(x)
# distributed contraction: fork
x = copy_to_parallel_region(x)
out_shape = list(x.size())
out_shape[-3] = self.lmax_local
out_shape[-3] = self.lmax
out_shape[-2] = self.mmax
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
......@@ -174,10 +168,7 @@ class InverseRealSHT(nn.Module):
pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# shard the pct along the n dim
pct = scatter_to_parallel_region(pct, dim=1)
self.lmax_local = pct.shape[1]
# register buffer
self.register_buffer('pct', pct, persistent=False)
def extra_repr(self):
......@@ -188,7 +179,7 @@ class InverseRealSHT(nn.Module):
def forward(self, x: torch.Tensor):
assert(x.shape[-2] == self.lmax_local)
assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax)
# Evaluate associated Legendre functions on the output nodes
......@@ -198,9 +189,6 @@ class InverseRealSHT(nn.Module):
im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct )
xs = torch.stack((rl, im), -1)
# distributed contraction: join
xs = reduce_from_parallel_region(xs)
# apply the inverse (real) FFT
x = torch.view_as_complex(xs)
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......@@ -267,10 +255,6 @@ class RealVectorSHT(nn.Module):
# since the second component is imaginary, we need to take complex conjugation into account
weights[1] = -1 * weights[1]
# shard the weights along n, because we want to be distributed in spectral space:
weights = scatter_to_parallel_region(weights, dim=2)
self.lmax_local = weights.shape[2]
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
......@@ -291,9 +275,8 @@ class RealVectorSHT(nn.Module):
x = torch.view_as_real(x)
# distributed contraction: fork
x = copy_to_parallel_region(x)
out_shape = list(x.size())
out_shape[-3] = self.lmax_local
out_shape[-3] = self.lmax
out_shape[-2] = self.mmax
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
......@@ -356,10 +339,7 @@ class InverseRealVectorSHT(nn.Module):
dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# shard the pct along the n dim
dpct = scatter_to_parallel_region(dpct, dim=2)
self.lmax_local = dpct.shape[2]
# register weights
self.register_buffer('dpct', dpct, persistent=False)
def extra_repr(self):
......@@ -370,7 +350,7 @@ class InverseRealVectorSHT(nn.Module):
def forward(self, x: torch.Tensor):
assert(x.shape[-2] == self.lmax_local)
assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax)
# Evaluate associated Legendre functions on the output nodes
......@@ -397,9 +377,6 @@ class InverseRealVectorSHT(nn.Module):
t = torch.stack((trl, tim), -1)
xs = torch.stack((s, t), -4)
# distributed contraction: join
xs = reduce_from_parallel_region(xs)
# apply the inverse (real) FFT
x = torch.view_as_complex(xs)
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......