Commit ef4f6518 authored by Boris Bonev

v0.5 update

parent 77ac7836
@@ -3,3 +3,4 @@ The code was authored by the following people:
Boris Bonev - NVIDIA Corporation
Christian Hundt - NVIDIA Corporation
Thorsten Kurth - NVIDIA Corporation
Nikola Kovachki - NVIDIA Corporation
@@ -2,6 +2,11 @@
## Versioning
### v0.5
* Reworked distributed SHT
* Module for sampling Gaussian Random Fields on the sphere
### v0.4
* Computation of associated Legendre polynomials
@@ -33,12 +38,3 @@
* Single GPU forward and backward transform
* Minimal code example and notebook
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@@ -36,7 +36,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<!-- ## What is torch-harmonics? -->

Spherical Harmonic Transforms (SHTs) are the counterpart to Fourier transforms on the sphere. As such, they are an invaluable tool for signal processing on the sphere.

`torch_harmonics` is a differentiable implementation of the SHT in PyTorch. It uses quadrature rules to compute the projection onto the associated Legendre polynomials, and FFTs for the projection onto the harmonic basis. For most practical purposes, this algorithm tends to outperform methods with better asymptotic scaling.

`torch_harmonics` uses PyTorch primitives to implement these operations, making it fully differentiable (see the gradient example below). Moreover, the quadrature can be distributed across multiple ranks, enabling spatially distributed SHTs.

`torch_harmonics` has been used to implement a variety of differentiable PDE solvers, which generated the animations below.
<table border="0" cellspacing="0" cellpadding="0">

@@ -76,6 +82,7 @@ docker run --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=671
- Boris Bonev (bbonev@nvidia.com)
- Christian Hundt (chundt@nvidia.com)
- Thorsten Kurth (tkurth@nvidia.com)
- Nikola Kovachki (nkovachki@nvidia.com)
## Implementation

The implementation follows the paper "Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations", N. Schaeffer, G3: Geochemistry, Geophysics, Geosystems.
@@ -126,7 +133,7 @@ The main functionality of `torch_harmonics` is provided in the form of `torch.nn

```python
import torch
import torch_harmonics as th

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@@ -136,11 +143,13 @@ batch_size = 32
signal = torch.randn(batch_size, nlat, nlon, device=device)

# transform data on an equiangular grid
sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device).float()

coeffs = sht(signal)
```
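Since the transform is built entirely from PyTorch primitives, gradients can be propagated through it with autograd. The following minimal sketch (grid size and loss are arbitrary, illustrative choices) backpropagates through a round trip of the forward and inverse transforms:

```python
import torch
import torch_harmonics as th

nlat, nlon = 180, 360
sht = th.RealSHT(nlat, nlon, grid="equiangular")
isht = th.InverseRealSHT(nlat, nlon, grid="equiangular")

signal = torch.randn(1, nlat, nlon, requires_grad=True)

# round trip through spectral space and backpropagate a scalar loss
loss = isht(sht(signal)).pow(2).sum()
loss.backward()
print(signal.grad.shape)  # torch.Size([1, 180, 360])
```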
`torch_harmonics` also implements a distributed variant of the SHT, located in `torch_harmonics.distributed`.
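Below is a minimal sketch of how the distributed variant might be set up; it assumes `torch.distributed` has already been initialized and that `h_group` and `w_group` are process groups for the polar (latitude) and azimuth (longitude) directions, constructed as in the distributed test further down. All names and grid sizes here are illustrative:

```python
import torch
import torch.distributed as dist  # assumed to be initialized elsewhere
import torch_harmonics.distributed as thd

# register the polar and azimuth process groups with torch_harmonics
thd.init(h_group, w_group)

# each rank applies the transform to its local shard of the global grid
dist_sht = thd.DistributedRealSHT(nlat=721, nlon=1440).to(device)
coeffs_local = dist_sht(signal_local)
```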
## References

<a id="1">[1]</a>
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@@ -333,7 +333,7 @@ class ShallowWaterSolver(nn.Module):
        return out

    def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
        """
        Plotting routine for data on the grid. Requires cartopy for 3d plots.
        """
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
@@ -33,7 +33,7 @@ from setuptools import setup

setup(
    name='torch_harmonics',
    version='0.5',
    author='Boris Bonev',
    author_email='bbonev@nvidia.com',
    packages=['torch_harmonics',],
@@ -29,86 +29,57 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

# we need this in order to enable distributed
import torch
import torch.distributed as dist

# those need to be global
_POLAR_PARALLEL_GROUP = None
_AZIMUTH_PARALLEL_GROUP = None
_IS_INITIALIZED = False


def polar_group():
    return _POLAR_PARALLEL_GROUP


def azimuth_group():
    return _AZIMUTH_PARALLEL_GROUP


def init(polar_process_group, azimuth_process_group):
    global _POLAR_PARALLEL_GROUP
    global _AZIMUTH_PARALLEL_GROUP
    global _IS_INITIALIZED
    _POLAR_PARALLEL_GROUP = polar_process_group
    _AZIMUTH_PARALLEL_GROUP = azimuth_process_group
    _IS_INITIALIZED = True


def is_initialized() -> bool:
    return _IS_INITIALIZED


def is_distributed_polar() -> bool:
    return (_POLAR_PARALLEL_GROUP is not None)


def is_distributed_azimuth() -> bool:
    return (_AZIMUTH_PARALLEL_GROUP is not None)


def polar_group_size() -> int:
    if not is_distributed_polar():
        return 1
    else:
        return dist.get_world_size(group=_POLAR_PARALLEL_GROUP)


def azimuth_group_size() -> int:
    if not is_distributed_azimuth():
        return 1
    else:
        return dist.get_world_size(group=_AZIMUTH_PARALLEL_GROUP)


def polar_group_rank() -> int:
    if not is_distributed_polar():
        return 0
    else:
        return dist.get_rank(group=_POLAR_PARALLEL_GROUP)


def azimuth_group_rank() -> int:
    if not is_distributed_azimuth():
        return 0
    else:
        return dist.get_rank(group=_AZIMUTH_PARALLEL_GROUP)
@@ -36,9 +36,10 @@ sys.path.append("..")
sys.path.append(".")

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd

try:
    from tqdm import tqdm
@@ -46,70 +47,169 @@ except:
    tqdm = lambda x : x

# set up distributed
world_rank = int(os.getenv('WORLD_RANK', 0))
grid_size_h = int(os.getenv('GRID_H', 1))
grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
world_size = grid_size_h * grid_size_w

dist.init_process_group(backend='nccl',
                        init_method=f"tcp://{master_address}:{port}",
                        rank=world_rank,
                        world_size=world_size)

local_rank = world_rank % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")

# compute local ranks in h and w:
# rank = wrank + grid_size_w * hrank
wrank = world_rank % grid_size_w
hrank = world_rank // grid_size_w
w_group = None
h_group = None

# now set up the comm grid:
wgroups = []
for h in range(grid_size_h):
    start = h * grid_size_w
    end = start + grid_size_w
    wgroups.append(list(range(start, end)))

print(wgroups)
for grp in wgroups:
    if len(grp) == 1:
        continue
    tmp_group = dist.new_group(ranks=grp)
    if world_rank in grp:
        w_group = tmp_group

# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]

print(hgroups)
for grp in hgroups:
    if len(grp) == 1:
        continue
    tmp_group = dist.new_group(ranks=grp)
    if world_rank in grp:
        h_group = tmp_group

# set device
torch.cuda.set_device(device.index)

# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)

if world_rank == 0:
    print(f"Running distributed test on grid H x W = {grid_size_h} x {grid_size_w}")

# initializing sht
thd.init(h_group, w_group)

# common parameters
B, C, H, W = 1, 8, 721, 1440
Hloc = (H + grid_size_h - 1) // grid_size_h
Wloc = (W + grid_size_w - 1) // grid_size_w
Hpad = grid_size_h * Hloc - H
Wpad = grid_size_w * Wloc - W

# do serial tests first:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W).to(device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W).to(device)
Lloc = (forward_transform_dist.lpad + forward_transform_dist.lmax) // grid_size_h
Mloc = (forward_transform_dist.mpad + forward_transform_dist.mmax) // grid_size_w

# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device)

# pad
with torch.no_grad():
    inp_pad = F.pad(inp_full, (0, Wpad, 0, Hpad))

    # split in W
    inp_local = torch.split(inp_pad, split_size_or_sections=Wloc, dim=-1)[wrank]
    # split in H
    inp_local = torch.split(inp_local, split_size_or_sections=Hloc, dim=-2)[hrank]

# do FWD transform
out_full = forward_transform_local(inp_full)
out_local = forward_transform_dist(inp_local)

# gather the local data
# gather in W
if grid_size_w > 1:
    olist = [torch.empty_like(out_local) for _ in range(grid_size_w)]
    olist[wrank] = out_local
    dist.all_gather(olist, out_local, group=w_group)
    out_full_gather = torch.cat(olist, dim=-1)
    out_full_gather = out_full_gather[..., :forward_transform_dist.mmax]
else:
    out_full_gather = out_local

# gather in h
if grid_size_h > 1:
    olist = [torch.empty_like(out_full_gather) for _ in range(grid_size_h)]
    olist[hrank] = out_full_gather
    dist.all_gather(olist, out_full_gather, group=h_group)
    out_full_gather = torch.cat(olist, dim=-2)
    out_full_gather = out_full_gather[..., :forward_transform_dist.lmax, :]

if world_rank == 0:
    print(f"Local Out: sum={out_full.abs().sum().item()}, max={out_full.abs().max().item()}, min={out_full.abs().min().item()}")
    print(f"Dist Out: sum={out_full_gather.abs().sum().item()}, max={out_full_gather.abs().max().item()}, min={out_full_gather.abs().min().item()}")
    diff = (out_full - out_full_gather).abs()
    print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(out_full.abs().sum() + out_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
    print("")

# create split input grad
with torch.no_grad():
    # create full grad
    ograd_full = torch.randn_like(out_full)

    # pad
    ograd_pad = F.pad(ograd_full, [0, forward_transform_dist.mpad, 0, forward_transform_dist.lpad])

    # split in M
    ograd_local = torch.split(ograd_pad, split_size_or_sections=Mloc, dim=-1)[wrank]
    # split in H
    ograd_local = torch.split(ograd_local, split_size_or_sections=Lloc, dim=-2)[hrank]

# backward pass:
# local
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()

# distributed
inp_local.requires_grad = True
out_local = forward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()

# gather
# gather in W
if grid_size_w > 1:
    olist = [torch.empty_like(igrad_local) for _ in range(grid_size_w)]
    olist[wrank] = igrad_local
    dist.all_gather(olist, igrad_local, group=w_group)
    igrad_full_gather = torch.cat(olist, dim=-1)
    igrad_full_gather = igrad_full_gather[..., :W]
else:
    igrad_full_gather = igrad_local

# gather in h
if grid_size_h > 1:
    olist = [torch.empty_like(igrad_full_gather) for _ in range(grid_size_h)]
    olist[hrank] = igrad_full_gather
    dist.all_gather(olist, igrad_full_gather, group=h_group)
    igrad_full_gather = torch.cat(olist, dim=-2)
    igrad_full_gather = igrad_full_gather[..., :H, :]

if world_rank == 0:
    print(f"Local Grad: sum={igrad_full.abs().sum().item()}, max={igrad_full.abs().max().item()}, min={igrad_full.abs().min().item()}")
    print(f"Dist Grad: sum={igrad_full_gather.abs().sum().item()}, max={igrad_full_gather.abs().max().item()}, min={igrad_full_gather.abs().min().item()}")
    diff = (igrad_full - igrad_full_gather).abs()
    print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(igrad_full.abs().sum() + igrad_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
@@ -31,3 +31,4 @@
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature
from . import random_fields
@@ -30,5 +30,10 @@
#
# we need this in order to enable distributed
from .utils import init, is_initialized, polar_group, azimuth_group
from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank
from .primitives import distributed_transpose_azimuth, distributed_transpose_polar
# import the sht stuff
from .distributed_sht import DistributedRealSHT, DistributedInverseRealSHT
from .distributed_sht import DistributedRealVectorSHT, DistributedInverseRealVectorSHT
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import numpy as np
import torch
import torch.nn as nn
import torch.fft
import torch.nn.functional as F
from torch_harmonics.quadrature import *
from torch_harmonics.legendre import *
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
class DistributedRealSHT(nn.Module):
    """
    Defines a module for computing the forward (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last two dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
        """
        Initializes the SHT Layer, precomputing the necessary quadrature weights.

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        grid: grid in the latitude direction (for now only tensor product grids are supported)
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # TODO: include assertions regarding the dimensions

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # spatial paddings
        latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
        self.nlatpad = latdist * self.comm_size_polar - self.nlat
        londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
        self.nlonpad = londist * self.comm_size_azimuth - self.nlon

        # frequency paddings
        ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
        self.lpad = ldist * self.comm_size_polar - self.lmax
        mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
        self.mpad = mdist * self.comm_size_azimuth - self.mmax

        # combine quadrature weights with the legendre weights
        weights = torch.from_numpy(w)
        pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
        weights = torch.einsum('mlk,k->mlk', pct, weights)

        # we need to split in m, pad before:
        weights = F.pad(weights, [0, 0, 0, 0, 0, self.mpad], mode="constant")
        weights = torch.split(weights, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]

        # compute the local pad and size
        # spatial
        self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
        self.nlatpad_local = latdist - self.nlat_local
        self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
        self.nlonpad_local = londist - self.nlon_local
        # frequency
        self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
        self.lpad_local = ldist - self.lmax_local
        self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
        self.mpad_local = mdist - self.mmax_local

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        # we need to ensure that we can split the channels evenly
        assert(x.shape[1] % self.comm_size_polar == 0)
        assert(x.shape[1] % self.comm_size_azimuth == 0)

        # h and w is split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            xt = distributed_transpose_azimuth.apply(x, (1, -1))
        else:
            xt = x

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        xtf = 2.0 * torch.pi * torch.fft.rfft(xt, n=self.nlon, dim=-1, norm="forward")

        # truncate
        xtft = xtf[..., :self.mmax]

        # pad the dim to allow for splitting
        xtfp = F.pad(xtft, [0, self.mpad], mode="constant")

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            y = distributed_transpose_azimuth.apply(xtfp, (-1, 1))
        else:
            y = xtfp

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            yt = distributed_transpose_polar.apply(y, (1, -2))
        else:
            yt = y

        # the input data might be padded, make sure to truncate to nlat:
        ytt = yt[..., :self.nlat, :]

        # do the Legendre-Gauss quadrature
        yttr = torch.view_as_real(ytt)

        # contraction
        yor = torch.einsum('...kmr,mlk->...lmr', yttr, self.weights.to(yttr.dtype)).contiguous()

        # pad if required, truncation is implicit
        yopr = F.pad(yor, [0, 0, 0, 0, 0, self.lpad], mode="constant")
        yop = torch.view_as_complex(yopr)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            y = distributed_transpose_polar.apply(yop, (-2, 1))
        else:
            y = yop

        return y
class DistributedInverseRealSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    nlat, nlon: Output dimensions
    lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # spatial paddings
        latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
        self.nlatpad = latdist * self.comm_size_polar - self.nlat
        londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
        self.nlonpad = londist * self.comm_size_azimuth - self.nlon

        # frequency paddings
        ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
        self.lpad = ldist * self.comm_size_polar - self.lmax
        mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
        self.mpad = mdist * self.comm_size_azimuth - self.mmax

        # compute legendre polynomials
        pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m
        pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
        pct = torch.split(pct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]

        # compute the local pads and sizes
        # spatial
        self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
        self.nlatpad_local = latdist - self.nlat_local
        self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
        self.nlonpad_local = londist - self.nlon_local
        # frequency
        self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
        self.lpad_local = ldist - self.lmax_local
        self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
        self.mpad_local = mdist - self.mmax_local

        # register
        self.register_buffer('pct', pct, persistent=False)

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        # we need to ensure that we can split the channels evenly
        assert(x.shape[1] % self.comm_size_polar == 0)
        assert(x.shape[1] % self.comm_size_azimuth == 0)

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            xt = distributed_transpose_polar.apply(x, (1, -2))
        else:
            xt = x

        # remove padding in l:
        xtt = xt[..., :self.lmax, :]

        # Evaluate associated Legendre functions on the output nodes
        xttr = torch.view_as_real(xtt)

        # einsum
        xs = torch.einsum('...lmr, mlk->...kmr', xttr, self.pct.to(xttr.dtype)).contiguous()
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and channels are local
        xp = F.pad(x, [0, 0, 0, self.nlatpad])
        if self.comm_size_polar > 1:
            y = distributed_transpose_polar.apply(xp, (-2, 1))
        else:
            y = xp

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            yt = distributed_transpose_azimuth.apply(y, (1, -1))
        else:
            yt = y

        # truncate
        ytt = yt[..., :self.mmax]

        # apply the inverse (real) FFT
        x = torch.fft.irfft(ytt, n=self.nlon, dim=-1, norm="forward")

        # pad before we transpose back
        xp = F.pad(x, [0, self.nlonpad])

        # transpose: after this, m is split and channels are local
        if self.comm_size_azimuth > 1:
            out = distributed_transpose_azimuth.apply(xp, (-1, 1))
        else:
            out = xp

        return out
class DistributedRealVectorSHT(nn.Module):
    """
    Defines a module for computing the forward (real) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
    The SHT is applied to the last three dimensions of the input.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):
        """
        Initializes the vector SHT Layer, precomputing the necessary quadrature weights.

        Parameters:
        nlat: input grid resolution in the latitudinal direction
        nlon: input grid resolution in the longitudinal direction
        grid: type of grid the data lives on
        """

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, w = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, w = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, w = clenshaw_curtiss_weights(nlat, -1, 1)
            # cost, w = fejer2_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        tq = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # spatial paddings
        latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
        self.nlatpad = latdist * self.comm_size_polar - self.nlat
        londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
        self.nlonpad = londist * self.comm_size_azimuth - self.nlon

        # frequency paddings
        ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
        self.lpad = ldist * self.comm_size_polar - self.lmax
        mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
        self.mpad = mdist * self.comm_size_azimuth - self.mmax

        weights = torch.from_numpy(w)
        dpct = precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)

        # combine integration weights, normalization factor in to one:
        l = torch.arange(0, self.lmax)
        norm_factor = 1. / l / (l+1)
        norm_factor[0] = 1.
        weights = torch.einsum('dmlk,k,l->dmlk', dpct, weights, norm_factor)
        # since the second component is imaginary, we need to take complex conjugation into account
        weights[1] = -1 * weights[1]

        # we need to split in m, pad before:
        weights = F.pad(weights, [0, 0, 0, 0, 0, self.mpad], mode="constant")
        weights = torch.split(weights, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=1)[self.comm_rank_azimuth]

        # remember quadrature weights
        self.register_buffer('weights', weights, persistent=False)

        # compute the local pad and size
        # spatial
        self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
        self.nlatpad_local = latdist - self.nlat_local
        self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
        self.nlonpad_local = londist - self.nlon_local
        # frequency
        self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
        self.lpad_local = ldist - self.lmax_local
        self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
        self.mpad_local = mdist - self.mmax_local

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        assert(len(x.shape) >= 3)
        assert(x.shape[1] % self.comm_size_polar == 0)
        assert(x.shape[1] % self.comm_size_azimuth == 0)

        # h and w is split. First we make w local by transposing into channel dim
        if self.comm_size_azimuth > 1:
            xt = distributed_transpose_azimuth.apply(x, (1, -1))
        else:
            xt = x

        # apply real fft in the longitudinal direction: make sure to truncate to nlon
        xtf = 2.0 * torch.pi * torch.fft.rfft(xt, n=self.nlon, dim=-1, norm="forward")

        # truncate
        xtft = xtf[..., :self.mmax]

        # pad the dim to allow for splitting
        xtfp = F.pad(xtft, [0, self.mpad], mode="constant")

        # transpose: after this, m is split and c is local
        if self.comm_size_azimuth > 1:
            y = distributed_transpose_azimuth.apply(xtfp, (-1, 1))
        else:
            y = xtfp

        # transpose: after this, c is split and h is local
        if self.comm_size_polar > 1:
            yt = distributed_transpose_polar.apply(y, (1, -2))
        else:
            yt = y

        # the input data might be padded, make sure to truncate to nlat:
        ytt = yt[..., :self.nlat, :]

        # do the Legendre-Gauss quadrature
        yttr = torch.view_as_real(ytt)

        # create output array
        yor = torch.zeros_like(yttr, dtype=yttr.dtype, device=yttr.device)

        # contraction - spheroidal component
        # real component
        yor[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 0], self.weights[0].to(yttr.dtype)) \
                             - torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 1], self.weights[1].to(yttr.dtype))
        # imag component
        yor[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 1], self.weights[0].to(yttr.dtype)) \
                             + torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 0], self.weights[1].to(yttr.dtype))

        # contraction - toroidal component
        # real component
        yor[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 1], self.weights[1].to(yttr.dtype)) \
                               - torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 0], self.weights[0].to(yttr.dtype))
        # imag component
        yor[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 0], self.weights[1].to(yttr.dtype)) \
                             - torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 1], self.weights[0].to(yttr.dtype))

        # pad if required
        yopr = F.pad(yor, [0, 0, 0, 0, 0, self.lpad], mode="constant")
        yop = torch.view_as_complex(yopr)

        # transpose: after this, l is split and c is local
        if self.comm_size_polar > 1:
            y = distributed_transpose_polar.apply(yop, (-2, 1))
        else:
            y = yop

        return y
class DistributedInverseRealVectorSHT(nn.Module):
    """
    Defines a module for computing the inverse (real-valued) vector SHT.
    Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.

    [1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
    [2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
    """

    def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="lobatto", norm="ortho", csphase=True):

        super().__init__()

        self.nlat = nlat
        self.nlon = nlon
        self.grid = grid
        self.norm = norm
        self.csphase = csphase

        # compute quadrature points
        if self.grid == "legendre-gauss":
            cost, _ = legendre_gauss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        elif self.grid == "lobatto":
            cost, _ = lobatto_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat-1
        elif self.grid == "equiangular":
            cost, _ = clenshaw_curtiss_weights(nlat, -1, 1)
            self.lmax = lmax or self.nlat
        else:
            raise ValueError("Unknown quadrature mode")

        # get the comms grid:
        self.comm_size_polar = polar_group_size()
        self.comm_rank_polar = polar_group_rank()
        self.comm_size_azimuth = azimuth_group_size()
        self.comm_rank_azimuth = azimuth_group_rank()

        # apply cosine transform and flip them
        t = np.flip(np.arccos(cost))

        # determine the dimensions
        self.mmax = mmax or self.nlon // 2 + 1

        # spatial paddings
        latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
        self.nlatpad = latdist * self.comm_size_polar - self.nlat
        londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
        self.nlonpad = londist * self.comm_size_azimuth - self.nlon

        # frequency paddings
        ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
        self.lpad = ldist * self.comm_size_polar - self.lmax
        mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
        self.mpad = mdist * self.comm_size_azimuth - self.mmax

        # compute legendre polynomials
        dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

        # split in m (note: m is dimension 1 of dpct)
        dpct = F.pad(dpct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
        dpct = torch.split(dpct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=1)[self.comm_rank_azimuth]

        # register buffer
        self.register_buffer('dpct', dpct, persistent=False)

        # compute the local pad and size
        # spatial
        self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
        self.nlatpad_local = latdist - self.nlat_local
        self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
        self.nlonpad_local = londist - self.nlon_local
        # frequency
        self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
        self.lpad_local = ldist - self.lmax_local
        self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
        self.mpad_local = mdist - self.mmax_local

    def extra_repr(self):
        """
        Pretty print module
        """
        return f'nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}'

    def forward(self, x: torch.Tensor):
        assert(x.shape[1] % self.comm_size_polar == 0)
        assert(x.shape[1] % self.comm_size_azimuth == 0)

        # transpose: after that, channels are split, l is local:
        if self.comm_size_polar > 1:
            xt = distributed_transpose_polar.apply(x, (1, -2))
        else:
            xt = x

        # remove padding in l:
        xtt = xt[..., :self.lmax, :]

        # Evaluate associated Legendre functions on the output nodes
        xttr = torch.view_as_real(xtt)

        # contraction - spheroidal component
        # real component
        srl = torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 0], self.dpct[0].to(xttr.dtype)) \
            - torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 1], self.dpct[1].to(xttr.dtype))
        # imag component
        sim = torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 1], self.dpct[0].to(xttr.dtype)) \
            + torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 0], self.dpct[1].to(xttr.dtype))

        # contraction - toroidal component
        # real component
        trl = - torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 1], self.dpct[1].to(xttr.dtype)) \
              - torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 0], self.dpct[0].to(xttr.dtype))
        # imag component
        tim = torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 0], self.dpct[1].to(xttr.dtype)) \
            - torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 1], self.dpct[0].to(xttr.dtype))

        # reassemble
        s = torch.stack((srl, sim), -1)
        t = torch.stack((trl, tim), -1)
        xs = torch.stack((s, t), -4)

        # convert to complex
        x = torch.view_as_complex(xs)

        # transpose: after this, l is split and channels are local
        xp = F.pad(x, [0, 0, 0, self.nlatpad])
        if self.comm_size_polar > 1:
            y = distributed_transpose_polar.apply(xp, (-2, 1))
        else:
            y = xp

        # transpose: after this, channels are split and m is local
        if self.comm_size_azimuth > 1:
            yt = distributed_transpose_azimuth.apply(y, (1, -1))
        else:
            yt = y

        # truncate
        ytt = yt[..., :self.mmax]

        # apply the inverse (real) FFT
        x = torch.fft.irfft(ytt, n=self.nlon, dim=-1, norm="forward")

        # pad before we transpose back
        xp = F.pad(x, [0, self.nlonpad])

        # transpose: after this, m is split and channels are local
        if self.comm_size_azimuth > 1:
            out = distributed_transpose_azimuth.apply(xp, (-1, 1))
        else:
            out = xp

        return out
@@ -32,7 +32,7 @@
import torch
import torch.distributed as dist

from .utils import polar_group, azimuth_group, is_initialized

# general helpers
def get_memory_format(tensor):
@@ -50,163 +50,54 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
    return tensor_list
def _transpose(tensor, dim0, dim1, group=None, async_op=False):
    # get input format
    input_format = get_memory_format(tensor)

    # get comm params
    comm_size = dist.get_world_size(group=group)

    # split and local transposition
    split_size = tensor.shape[dim0] // comm_size
    x_send = [y.contiguous(memory_format=input_format) for y in torch.split(tensor, split_size, dim=dim0)]
    x_recv = [torch.empty_like(x_send[0]).contiguous(memory_format=input_format) for _ in range(comm_size)]

    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

    return x_recv, req


class distributed_transpose_azimuth(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, dim):
        xlist, _ = _transpose(x, dim[0], dim[1], group=azimuth_group())
        x = torch.cat(xlist, dim=dim[1])
        ctx.dim = dim
        return x

    @staticmethod
    def backward(ctx, go):
        dim = ctx.dim
        gilist, _ = _transpose(go, dim[1], dim[0], group=azimuth_group())
        gi = torch.cat(gilist, dim=dim[0])
        return gi, None


class distributed_transpose_polar(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, dim):
        xlist, _ = _transpose(x, dim[0], dim[1], group=polar_group())
        x = torch.cat(xlist, dim=dim[1])
        ctx.dim = dim
        return x

    @staticmethod
    def backward(ctx, go):
        dim = ctx.dim
        gilist, _ = _transpose(go, dim[1], dim[0], group=polar_group())
        gi = torch.cat(gilist, dim=dim[0])
        return gi, None
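# Usage sketch (hypothetical shapes; assumes the azimuth group has size P and
# torch_harmonics.distributed.init(...) has been called): the pair of calls
# below trades the channel dimension for the longitude dimension, so that a
# longitude-local operation (e.g. the FFT inside the SHT) can be applied,
# before transposing back:
#
#   x_local: (B, C, H, W // P)
#   xt = distributed_transpose_azimuth.apply(x_local, (1, -1))   # (B, C // P, H, W)
#   ...apply a longitude-local operation...
#   x_local = distributed_transpose_azimuth.apply(xt, (-1, 1))   # (B, C, H, W // P)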
@@ -34,15 +34,52 @@ import torch
import torch.distributed as dist

# those need to be global
_POLAR_PARALLEL_GROUP = None
_AZIMUTH_PARALLEL_GROUP = None
_IS_INITIALIZED = False


def polar_group():
    return _POLAR_PARALLEL_GROUP


def azimuth_group():
    return _AZIMUTH_PARALLEL_GROUP


def init(polar_process_group, azimuth_process_group):
    global _POLAR_PARALLEL_GROUP
    global _AZIMUTH_PARALLEL_GROUP
    global _IS_INITIALIZED
    _POLAR_PARALLEL_GROUP = polar_process_group
    _AZIMUTH_PARALLEL_GROUP = azimuth_process_group
    _IS_INITIALIZED = True


def is_initialized() -> bool:
    return _IS_INITIALIZED


def is_distributed_polar() -> bool:
    return (_POLAR_PARALLEL_GROUP is not None)


def is_distributed_azimuth() -> bool:
    return (_AZIMUTH_PARALLEL_GROUP is not None)


def polar_group_size() -> int:
    if not is_distributed_polar():
        return 1
    else:
        return dist.get_world_size(group=_POLAR_PARALLEL_GROUP)


def azimuth_group_size() -> int:
    if not is_distributed_azimuth():
        return 1
    else:
        return dist.get_world_size(group=_AZIMUTH_PARALLEL_GROUP)


def polar_group_rank() -> int:
    if not is_distributed_polar():
        return 0
    else:
        return dist.get_rank(group=_POLAR_PARALLEL_GROUP)


def azimuth_group_rank() -> int:
    if not is_distributed_azimuth():
        return 0
    else:
        return dist.get_rank(group=_AZIMUTH_PARALLEL_GROUP)
@@ -53,7 +53,8 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True)
    """

    # compute the tensor P^m_n:
    nmax = max(mmax, lmax)
    pct = np.zeros((nmax, nmax, len(t)), dtype=np.float64)

    sint = np.sin(t)
    cost = np.cos(t)

@@ -65,23 +66,25 @@
    pct[0,0,:] = norm_factor / np.sqrt(4 * np.pi)

    # fill the diagonal and the lower diagonal
    for l in range(1, nmax):
        pct[l-1, l, :] = np.sqrt(2*l + 1) * cost * pct[l-1, l-1, :]
        pct[l, l, :] = np.sqrt( (2*l + 1) * (1 + cost) * (1 - cost) / 2 / l ) * pct[l-1, l-1, :]

    # fill the remaining values on the upper triangle and multiply b
    for l in range(2, nmax):
        for m in range(0, l-1):
            pct[m, l, :] = cost * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * pct[m, l-1, :] \
                - np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * pct[m, l-2, :]

    if norm == "schmidt":
        for l in range(0, nmax):
            if inverse:
                pct[:, l, :] = pct[:, l, :] * np.sqrt(2*l + 1)
            else:
                pct[:, l, :] = pct[:, l, :] / np.sqrt(2*l + 1)

    pct = pct[:mmax, :lmax]

    if csphase:
        for m in range(1, mmax, 2):
            pct[m] *= -1
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch

from .sht import InverseRealSHT


class GaussianRandomFieldS2(torch.nn.Module):
    def __init__(self, nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid="equiangular", dtype=torch.float32):
        """A mean-zero Gaussian Random Field on the sphere with Matern covariance:
        C = sigma^2 (-Lap + tau^2 I)^(-alpha).

        Lap is the Laplacian on the sphere, I the identity operator,
        and sigma, tau, alpha are scalar parameters.

        Note: C is trace-class on L^2 if and only if alpha > 1.

        Parameters
        ----------
        nlat : int
            Number of latitudinal modes;
            longitudinal modes are 2*nlat.
        alpha : float, default is 2
            Regularity parameter. Larger means smoother.
        tau : float, default is 3
            Length-scale parameter. Larger means more scales.
        sigma : float, default is None
            Scale parameter controlling the overall amplitude.
            If None, sigma = tau**(0.5*(2*alpha - 2.0)).
        radius : float, default is 1
            Radius of the sphere.
        grid : string, default is "equiangular"
            Grid type. Currently supports "equiangular" and
            "legendre-gauss".
        dtype : torch.dtype, default is torch.float32
            Numerical type for the calculations.
        """
        super().__init__()

        # Number of latitudinal modes.
        self.nlat = nlat

        # Default value of sigma if None is given.
        if sigma is None:
            assert alpha > 1.0, f"Alpha must be greater than one, got {alpha}."
            sigma = tau**(0.5*(2*alpha - 2.0))

        # Inverse SHT
        self.isht = InverseRealSHT(self.nlat, 2*self.nlat, grid=grid, norm='backward').to(dtype=dtype)

        # Square root of the eigenvalues of C.
        sqrt_eig = torch.tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat, 1).repeat(1, self.nlat+1)
        sqrt_eig = torch.tril(sigma*(((sqrt_eig/radius**2) + tau**2)**(-alpha/2.0)))
        sqrt_eig[0, 0] = 0.0
        sqrt_eig = sqrt_eig.unsqueeze(0)
        self.register_buffer('sqrt_eig', sqrt_eig)

        # Save mean and var of the standard Gaussian.
        # Need these to re-initialize the distribution on a new device.
        mean = torch.tensor([0.0]).to(dtype=dtype)
        var = torch.tensor([1.0]).to(dtype=dtype)
        self.register_buffer('mean', mean)
        self.register_buffer('var', var)

        # Standard normal noise sampler.
        self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)

    def forward(self, N, xi=None):
        """Sample random functions from a spherical GRF.

        Parameters
        ----------
        N : int
            Number of functions to sample.
        xi : torch.Tensor, default is None
            Noise is a complex tensor of size (N, nlat, nlat+1).
            If None, new Gaussian noise is sampled.
            If xi is provided, N is ignored.

        Returns
        -------
        u : torch.Tensor
            N random samples from the GRF returned as a
            tensor of size (N, nlat, 2*nlat) on the chosen grid.
        """
        # Sample Gaussian noise; squeeze only the trailing singleton
        # dimension so the batch dimension survives when N == 1.
        if xi is None:
            xi = self.gaussian_noise.sample(torch.Size((N, self.nlat, self.nlat + 1, 2))).squeeze(-1)
            xi = torch.view_as_complex(xi)

        # Karhunen-Loeve expansion.
        u = self.isht(xi*self.sqrt_eig)

        return u

    # Override the cuda and to methods so the sampler gets initialized
    # with mean and variance on the correct device.
    def cuda(self, *args, **kwargs):
        super().cuda(*args, **kwargs)
        self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
        return self

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
        return self
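A minimal sampling sketch using only the constructor arguments and output shapes documented above; the module path is an assumption and the parameter values are illustrative:

```python
import torch
from torch_harmonics.random_fields import GaussianRandomFieldS2  # assumed module path

# 64 latitudinal modes give samples on a 64 x 128 equiangular grid.
grf = GaussianRandomFieldS2(nlat=64, alpha=2.0, tau=3.0)
u = grf(4)
print(u.shape)  # expected: torch.Size([4, 64, 128]), i.e. (N, nlat, 2*nlat)
```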
...@@ -36,7 +36,6 @@ import torch.fft ...@@ -36,7 +36,6 @@ import torch.fft
from .quadrature import * from .quadrature import *
from .legendre import * from .legendre import *
from .distributed import copy_to_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
class RealSHT(nn.Module): class RealSHT(nn.Module):
...@@ -94,10 +93,6 @@ class RealSHT(nn.Module): ...@@ -94,10 +93,6 @@ class RealSHT(nn.Module):
pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase) pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
weights = torch.einsum('mlk,k->mlk', pct, weights) weights = torch.einsum('mlk,k->mlk', pct, weights)
# shard the weights along n, because we want to be distributed in spectral space:
weights = scatter_to_parallel_region(weights, dim=1)
self.lmax_local = weights.shape[1]
# remember quadrature weights # remember quadrature weights
self.register_buffer('weights', weights, persistent=False) self.register_buffer('weights', weights, persistent=False)
...@@ -119,9 +114,8 @@ class RealSHT(nn.Module): ...@@ -119,9 +114,8 @@ class RealSHT(nn.Module):
x = torch.view_as_real(x) x = torch.view_as_real(x)
# distributed contraction: fork # distributed contraction: fork
x = copy_to_parallel_region(x)
out_shape = list(x.size()) out_shape = list(x.size())
out_shape[-3] = self.lmax_local out_shape[-3] = self.lmax
out_shape[-2] = self.mmax out_shape[-2] = self.mmax
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
...@@ -174,10 +168,7 @@ class InverseRealSHT(nn.Module): ...@@ -174,10 +168,7 @@ class InverseRealSHT(nn.Module):
pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# shard the pct along the n dim # register buffer
pct = scatter_to_parallel_region(pct, dim=1)
self.lmax_local = pct.shape[1]
self.register_buffer('pct', pct, persistent=False) self.register_buffer('pct', pct, persistent=False)
def extra_repr(self): def extra_repr(self):
...@@ -188,7 +179,7 @@ class InverseRealSHT(nn.Module): ...@@ -188,7 +179,7 @@ class InverseRealSHT(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
assert(x.shape[-2] == self.lmax_local) assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax) assert(x.shape[-1] == self.mmax)
# Evaluate associated Legendre functions on the output nodes # Evaluate associated Legendre functions on the output nodes
...@@ -198,9 +189,6 @@ class InverseRealSHT(nn.Module): ...@@ -198,9 +189,6 @@ class InverseRealSHT(nn.Module):
im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct ) im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct )
xs = torch.stack((rl, im), -1) xs = torch.stack((rl, im), -1)
# distributed contraction: join
xs = reduce_from_parallel_region(xs)
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.view_as_complex(xs) x = torch.view_as_complex(xs)
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
...@@ -267,10 +255,6 @@ class RealVectorSHT(nn.Module): ...@@ -267,10 +255,6 @@ class RealVectorSHT(nn.Module):
# since the second component is imaginary, we need to take complex conjugation into account # since the second component is imaginary, we need to take complex conjugation into account
weights[1] = -1 * weights[1] weights[1] = -1 * weights[1]
# shard the weights along n, because we want to be distributed in spectral space:
weights = scatter_to_parallel_region(weights, dim=2)
self.lmax_local = weights.shape[2]
# remember quadrature weights # remember quadrature weights
self.register_buffer('weights', weights, persistent=False) self.register_buffer('weights', weights, persistent=False)
...@@ -291,9 +275,8 @@ class RealVectorSHT(nn.Module): ...@@ -291,9 +275,8 @@ class RealVectorSHT(nn.Module):
x = torch.view_as_real(x) x = torch.view_as_real(x)
# distributed contraction: fork # distributed contraction: fork
x = copy_to_parallel_region(x)
out_shape = list(x.size()) out_shape = list(x.size())
out_shape[-3] = self.lmax_local out_shape[-3] = self.lmax
out_shape[-2] = self.mmax out_shape[-2] = self.mmax
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device) xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
...@@ -356,10 +339,7 @@ class InverseRealVectorSHT(nn.Module): ...@@ -356,10 +339,7 @@ class InverseRealVectorSHT(nn.Module):
dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase) dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# shard the pct along the n dim # register weights
dpct = scatter_to_parallel_region(dpct, dim=2)
self.lmax_local = dpct.shape[2]
self.register_buffer('dpct', dpct, persistent=False) self.register_buffer('dpct', dpct, persistent=False)
def extra_repr(self): def extra_repr(self):
...@@ -370,7 +350,7 @@ class InverseRealVectorSHT(nn.Module): ...@@ -370,7 +350,7 @@ class InverseRealVectorSHT(nn.Module):
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
assert(x.shape[-2] == self.lmax_local) assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax) assert(x.shape[-1] == self.mmax)
# Evaluate associated Legendre functions on the output nodes # Evaluate associated Legendre functions on the output nodes
...@@ -397,9 +377,6 @@ class InverseRealVectorSHT(nn.Module): ...@@ -397,9 +377,6 @@ class InverseRealVectorSHT(nn.Module):
t = torch.stack((trl, tim), -1) t = torch.stack((trl, tim), -1)
xs = torch.stack((s, t), -4) xs = torch.stack((s, t), -4)
# distributed contraction: join
xs = reduce_from_parallel_region(xs)
# apply the inverse (real) FFT # apply the inverse (real) FFT
x = torch.view_as_complex(xs) x = torch.view_as_complex(xs)
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward") x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......
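With the sharding and reduction calls removed, the transforms above behave as plain single-device modules again (the distributed path now lives in the reworked distributed SHT). A hedged round-trip sketch; the top-level import path, the `(nlat, nlon, grid=...)` constructor signature, and the float64 dtype handling are assumptions based on the code above:

```python
import torch
from torch_harmonics import RealSHT, InverseRealSHT  # assumed public import path

nlat, nlon = 64, 128
sht = RealSHT(nlat, nlon, grid="legendre-gauss")
isht = InverseRealSHT(nlat, nlon, grid="legendre-gauss")

x = torch.randn(1, nlat, nlon, dtype=torch.float64)
coeffs = sht(x)        # complex coefficients of shape (1, lmax, mmax)
x_rec = isht(coeffs)   # synthesis back onto the (nlat, nlon) grid

# On the Legendre-Gauss grid the analysis-synthesis round trip should be
# accurate to floating-point precision.
print((x - x_rec).abs().max())
```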