Commit ef4f6518 authored by Boris Bonev

v0.5 update

parent 77ac7836
...@@ -3,3 +3,4 @@ The code was authored by the following people:
Boris Bonev - NVIDIA Corporation
Christian Hundt - NVIDIA Corporation
Thorsten Kurth - NVIDIA Corporation
Nikola Kovachki - NVIDIA Corporation
...@@ -2,6 +2,11 @@
## Versioning
### v0.5
* Reworked distributed SHT
* Module for sampling Gaussian Random Fields on the sphere
### v0.4
* Computation of associated Legendre polynomials
...@@ -32,13 +37,4 @@
### v0.1
* Single GPU forward and backward transform
* Minimal code example and notebook
\ No newline at end of file
<!-- ## Detailed logs
### 23-11-2022
* Initialized the library
* Added `getting_started.ipynb` example
* Added simple example to test the SHT
* Logo -->
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
...@@ -36,7 +36,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<!-- ## What is torch-harmonics? -->
Spherical Harmonic Transforms (SHTs) are the counterpart to Fourier transforms on the sphere. As such, they are an invaluable tool for signal processing on the sphere.

`torch_harmonics` is a differentiable implementation of the SHT in PyTorch. It uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others, with better asymptotic scaling for most practical purposes.

`torch_harmonics` uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed onto multiple ranks, making it spatially distributed.

`torch_harmonics` has been used to implement a variety of differentiable PDE solvers, which generated the animations below.
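As a minimal illustration of the quadrature idea (a NumPy sketch, not the library's internal code): a Gauss-Legendre rule with n nodes integrates polynomials of degree up to 2n - 1 over [-1, 1] exactly, which is what makes the projection onto Legendre polynomials exact for band-limited signals.

```python
import numpy as np

# Gauss-Legendre quadrature: n nodes integrate polynomials of degree
# up to 2n - 1 over [-1, 1] exactly.
nodes, weights = np.polynomial.legendre.leggauss(8)

# project f(x) = x^4 onto the Legendre polynomial P_2(x) = (3x^2 - 1)/2;
# the analytic value of the integral is 8/35
f = nodes**4
p2 = 0.5 * (3 * nodes**2 - 1)
coeff = np.sum(weights * f * p2)
print(coeff)  # ≈ 0.2285714... = 8/35
```

The integrand has degree 6, well below the exactness limit of the 8-node rule, so the quadrature reproduces the analytic value to machine precision.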
<table border="0" cellspacing="0" cellpadding="0">
...@@ -76,6 +82,7 @@ docker run --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=671
- Boris Bonev (bbonev@nvidia.com)
- Christian Hundt (chundt@nvidia.com)
- Thorsten Kurth (tkurth@nvidia.com)
- Nikola Kovachki (nkovachki@nvidia.com)
## Implementation
The implementation follows the paper "Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations", N. Schaeffer, G3: Geochemistry, Geophysics, Geosystems.
...@@ -126,7 +133,7 @@ The main functionality of `torch_harmonics` is provided in the form of `torch.nn
```python
import torch
import torch_harmonics as th

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
...@@ -136,11 +143,13 @@ batch_size = 32
signal = torch.randn(batch_size, nlat, nlon)

# transform data on an equiangular grid
sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device).float()
coeffs = sht(signal)
```
`torch_harmonics` also implements a distributed variant of the SHT, located in `torch_harmonics.distributed`.
## References
<a id="1">[1]</a>
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
...@@ -333,7 +333,7 @@ class ShallowWaterSolver(nn.Module):
        return out

    def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
        """
        plotting routine for data on the grid. Requires cartopy for 3d plots.
        """
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
...@@ -33,7 +33,7 @@ from setuptools import setup
setup(
    name='torch_harmonics',
    version='0.5',
    author='Boris Bonev',
    author_email='bbonev@nvidia.com',
    packages=['torch_harmonics',],
......
...@@ -29,86 +29,57 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

# we need this in order to enable distributed
import torch
import torch.distributed as dist

# those need to be global
_POLAR_PARALLEL_GROUP = None
_AZIMUTH_PARALLEL_GROUP = None
_IS_INITIALIZED = False


def polar_group():
    return _POLAR_PARALLEL_GROUP


def azimuth_group():
    return _AZIMUTH_PARALLEL_GROUP


def init(polar_process_group, azimuth_process_group):
    global _POLAR_PARALLEL_GROUP
    global _AZIMUTH_PARALLEL_GROUP
    global _IS_INITIALIZED
    _POLAR_PARALLEL_GROUP = polar_process_group
    _AZIMUTH_PARALLEL_GROUP = azimuth_process_group
    _IS_INITIALIZED = True


def is_initialized() -> bool:
    return _IS_INITIALIZED


def is_distributed_polar() -> bool:
    return (_POLAR_PARALLEL_GROUP is not None)


def is_distributed_azimuth() -> bool:
    return (_AZIMUTH_PARALLEL_GROUP is not None)


def polar_group_size() -> int:
    if not is_distributed_polar():
        return 1
    else:
        return dist.get_world_size(group=_POLAR_PARALLEL_GROUP)


def azimuth_group_size() -> int:
    if not is_distributed_azimuth():
        return 1
    else:
        return dist.get_world_size(group=_AZIMUTH_PARALLEL_GROUP)


def polar_group_rank() -> int:
    if not is_distributed_polar():
        return 0
    else:
        return dist.get_rank(group=_POLAR_PARALLEL_GROUP)


def azimuth_group_rank() -> int:
    if not is_distributed_azimuth():
        return 0
    else:
        return dist.get_rank(group=_AZIMUTH_PARALLEL_GROUP)
...@@ -36,9 +36,10 @@ sys.path.append("..")
sys.path.append(".")

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd

try:
    from tqdm import tqdm
...@@ -46,70 +47,169 @@ except:
    tqdm = lambda x: x

# set up distributed
world_rank = int(os.getenv('WORLD_RANK', 0))
grid_size_h = int(os.getenv('GRID_H', 1))
grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
world_size = grid_size_h * grid_size_w

dist.init_process_group(backend='nccl',
                        init_method=f"tcp://{master_address}:{port}",
                        rank=world_rank,
                        world_size=world_size)

local_rank = world_rank % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")

# compute local ranks in h and w:
# rank = wrank + grid_size_w * hrank
wrank = world_rank % grid_size_w
hrank = world_rank // grid_size_w
w_group = None
h_group = None

# now set up the comm grid:
wgroups = []
for h in range(grid_size_h):
    start = h * grid_size_w
    end = start + grid_size_w
    wgroups.append(list(range(start, end)))

print(wgroups)

for grp in wgroups:
    if len(grp) == 1:
        continue
    tmp_group = dist.new_group(ranks=grp)
    if world_rank in grp:
        w_group = tmp_group

# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
print(hgroups)

for grp in hgroups:
    if len(grp) == 1:
        continue
    tmp_group = dist.new_group(ranks=grp)
    if world_rank in grp:
        h_group = tmp_group

# set device
torch.cuda.set_device(device.index)

# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)

if world_rank == 0:
    print(f"Running distributed test on grid H x W = {grid_size_h} x {grid_size_w}")

# initializing sht
thd.init(h_group, w_group)

# common parameters
B, C, H, W = 1, 8, 721, 1440
Hloc = (H + grid_size_h - 1) // grid_size_h
Wloc = (W + grid_size_w - 1) // grid_size_w
Hpad = grid_size_h * Hloc - H
Wpad = grid_size_w * Wloc - W

# do serial tests first:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W).to(device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W).to(device)
Lloc = (forward_transform_dist.lpad + forward_transform_dist.lmax) // grid_size_h
Mloc = (forward_transform_dist.mpad + forward_transform_dist.mmax) // grid_size_w

# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device)

# pad
with torch.no_grad():
    inp_pad = F.pad(inp_full, (0, Wpad, 0, Hpad))

    # split in W
    inp_local = torch.split(inp_pad, split_size_or_sections=Wloc, dim=-1)[wrank]
    # split in H
    inp_local = torch.split(inp_local, split_size_or_sections=Hloc, dim=-2)[hrank]

# do FWD transform
out_full = forward_transform_local(inp_full)
out_local = forward_transform_dist(inp_local)

# gather the local data
# gather in W
if grid_size_w > 1:
    olist = [torch.empty_like(out_local) for _ in range(grid_size_w)]
    olist[wrank] = out_local
    dist.all_gather(olist, out_local, group=w_group)
    out_full_gather = torch.cat(olist, dim=-1)
    out_full_gather = out_full_gather[..., :forward_transform_dist.mmax]
else:
    out_full_gather = out_local

# gather in h
if grid_size_h > 1:
    olist = [torch.empty_like(out_full_gather) for _ in range(grid_size_h)]
    olist[hrank] = out_full_gather
    dist.all_gather(olist, out_full_gather, group=h_group)
    out_full_gather = torch.cat(olist, dim=-2)
    out_full_gather = out_full_gather[..., :forward_transform_dist.lmax, :]

if world_rank == 0:
    print(f"Local Out: sum={out_full.abs().sum().item()}, max={out_full.abs().max().item()}, min={out_full.abs().min().item()}")
    print(f"Dist Out: sum={out_full_gather.abs().sum().item()}, max={out_full_gather.abs().max().item()}, min={out_full_gather.abs().min().item()}")
    diff = (out_full - out_full_gather).abs()
    print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(out_full.abs().sum() + out_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
    print("")

# create split input grad
with torch.no_grad():
    # create full grad
    ograd_full = torch.randn_like(out_full)
    # pad
    ograd_pad = F.pad(ograd_full, [0, forward_transform_dist.mpad, 0, forward_transform_dist.lpad])
    # split in M
    ograd_local = torch.split(ograd_pad, split_size_or_sections=Mloc, dim=-1)[wrank]
    # split in L
    ograd_local = torch.split(ograd_local, split_size_or_sections=Lloc, dim=-2)[hrank]

# backward pass:
# local
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()

# distributed
inp_local.requires_grad = True
out_local = forward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()

# gather
# gather in W
if grid_size_w > 1:
    olist = [torch.empty_like(igrad_local) for _ in range(grid_size_w)]
    olist[wrank] = igrad_local
    dist.all_gather(olist, igrad_local, group=w_group)
    igrad_full_gather = torch.cat(olist, dim=-1)
    igrad_full_gather = igrad_full_gather[..., :W]
else:
    igrad_full_gather = igrad_local

# gather in h
if grid_size_h > 1:
    olist = [torch.empty_like(igrad_full_gather) for _ in range(grid_size_h)]
    olist[hrank] = igrad_full_gather
    dist.all_gather(olist, igrad_full_gather, group=h_group)
    igrad_full_gather = torch.cat(olist, dim=-2)
    igrad_full_gather = igrad_full_gather[..., :H, :]

if world_rank == 0:
    print(f"Local Grad: sum={igrad_full.abs().sum().item()}, max={igrad_full.abs().max().item()}, min={igrad_full.abs().min().item()}")
    print(f"Dist Grad: sum={igrad_full_gather.abs().sum().item()}, max={igrad_full_gather.abs().max().item()}, min={igrad_full_gather.abs().min().item()}")
    diff = (igrad_full - igrad_full_gather).abs()
    print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(igrad_full.abs().sum() + igrad_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
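The pad/split/gather/trim bookkeeping used by this test can be checked on a single process with NumPy. This is a sketch that emulates the process grid with a dict of tiles (no actual communication; `grid_size_h`, `grid_size_w`, and the tile names mirror the variables above but are otherwise illustrative):

```python
import numpy as np

# emulate the pad -> split -> gather -> trim round trip of the distributed test
grid_size_h, grid_size_w = 2, 3
H, W = 7, 8  # deliberately not divisible by the grid sizes

Hloc = (H + grid_size_h - 1) // grid_size_h   # ceil(H / grid_size_h)
Wloc = (W + grid_size_w - 1) // grid_size_w
Hpad = grid_size_h * Hloc - H                 # padding to make H divisible
Wpad = grid_size_w * Wloc - W

x = np.arange(H * W, dtype=np.float64).reshape(H, W)
x_pad = np.pad(x, ((0, Hpad), (0, Wpad)))

# each "rank" (hrank, wrank) owns one (Hloc, Wloc) tile
tiles = {(h, w): x_pad[h*Hloc:(h+1)*Hloc, w*Wloc:(w+1)*Wloc]
         for h in range(grid_size_h) for w in range(grid_size_w)}

# gather in W, then in H, then trim the padding away
rows = [np.concatenate([tiles[(h, w)] for w in range(grid_size_w)], axis=-1)
        for h in range(grid_size_h)]
x_gather = np.concatenate(rows, axis=-2)[:H, :W]
print(np.allclose(x, x_gather))
```

The ceil-division for `Hloc`/`Wloc` guarantees every rank holds an equally sized tile, and trimming after the gather removes exactly the rows and columns the padding introduced.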
...@@ -31,3 +31,4 @@
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature
from . import random_fields
...@@ -30,5 +30,10 @@
#
# we need this in order to enable distributed
from .utils import init, is_initialized, polar_group, azimuth_group
from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank
from .primitives import distributed_transpose_azimuth, distributed_transpose_polar

# import the sht stuff
from .distributed_sht import DistributedRealSHT, DistributedInverseRealSHT
from .distributed_sht import DistributedRealVectorSHT, DistributedInverseRealVectorSHT
...@@ -32,7 +32,7 @@
import torch
import torch.distributed as dist

from .utils import polar_group, azimuth_group, is_initialized

# general helpers
def get_memory_format(tensor):
...@@ -50,163 +50,54 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
    return tensor_list
def _transpose(tensor, dim0, dim1, group=None, async_op=False):
    # get input format
    input_format = get_memory_format(tensor)

    # get comm params
    comm_size = dist.get_world_size(group=group)

    # split and local transposition
    split_size = tensor.shape[dim0] // comm_size
    x_send = [y.contiguous(memory_format=input_format) for y in torch.split(tensor, split_size, dim=dim0)]
    x_recv = [torch.empty_like(x_send[0]).contiguous(memory_format=input_format) for _ in range(comm_size)]

    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

    return x_recv, req


class distributed_transpose_azimuth(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, dim):
        xlist, _ = _transpose(x, dim[0], dim[1], group=azimuth_group())
        x = torch.cat(xlist, dim=dim[1])
        ctx.dim = dim
        return x

    @staticmethod
    def backward(ctx, go):
        dim = ctx.dim
        gilist, _ = _transpose(go, dim[1], dim[0], group=azimuth_group())
        gi = torch.cat(gilist, dim=dim[0])
        return gi, None


class distributed_transpose_polar(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, dim):
        xlist, _ = _transpose(x, dim[0], dim[1], group=polar_group())
        x = torch.cat(xlist, dim=dim[1])
        ctx.dim = dim
        return x

    @staticmethod
    def backward(ctx, go):
        dim = ctx.dim
        gilist, _ = _transpose(go, dim[1], dim[0], group=polar_group())
        gi = torch.cat(gilist, dim=dim[0])
        return gi, None
...@@ -34,15 +34,52 @@ import torch
import torch.distributed as dist

# those need to be global
_POLAR_PARALLEL_GROUP = None
_AZIMUTH_PARALLEL_GROUP = None
_IS_INITIALIZED = False


def polar_group():
    return _POLAR_PARALLEL_GROUP


def azimuth_group():
    return _AZIMUTH_PARALLEL_GROUP


def init(polar_process_group, azimuth_process_group):
    global _POLAR_PARALLEL_GROUP
    global _AZIMUTH_PARALLEL_GROUP
    global _IS_INITIALIZED
    _POLAR_PARALLEL_GROUP = polar_process_group
    _AZIMUTH_PARALLEL_GROUP = azimuth_process_group
    _IS_INITIALIZED = True


def is_initialized() -> bool:
    return _IS_INITIALIZED


def is_distributed_polar() -> bool:
    return (_POLAR_PARALLEL_GROUP is not None)


def is_distributed_azimuth() -> bool:
    return (_AZIMUTH_PARALLEL_GROUP is not None)


def polar_group_size() -> int:
    if not is_distributed_polar():
        return 1
    else:
        return dist.get_world_size(group=_POLAR_PARALLEL_GROUP)


def azimuth_group_size() -> int:
    if not is_distributed_azimuth():
        return 1
    else:
        return dist.get_world_size(group=_AZIMUTH_PARALLEL_GROUP)


def polar_group_rank() -> int:
    if not is_distributed_polar():
        return 0
    else:
        return dist.get_rank(group=_POLAR_PARALLEL_GROUP)


def azimuth_group_rank() -> int:
    if not is_distributed_azimuth():
        return 0
    else:
        return dist.get_rank(group=_AZIMUTH_PARALLEL_GROUP)
...@@ -53,7 +53,8 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
    """
    # compute the tensor P^m_n:
    nmax = max(mmax, lmax)
    pct = np.zeros((nmax, nmax, len(t)), dtype=np.float64)

    sint = np.sin(t)
    cost = np.cos(t)
...@@ -65,23 +66,25 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True):
    pct[0, 0, :] = norm_factor / np.sqrt(4 * np.pi)

    # fill the diagonal and the lower diagonal
    for l in range(1, nmax):
        pct[l-1, l, :] = np.sqrt(2*l + 1) * cost * pct[l-1, l-1, :]
        pct[l, l, :] = np.sqrt((2*l + 1) * (1 + cost) * (1 - cost) / 2 / l) * pct[l-1, l-1, :]

    # fill the remaining values on the upper triangle via the three-term recurrence
    for l in range(2, nmax):
        for m in range(0, l-1):
            pct[m, l, :] = cost * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * pct[m, l-1, :] \
                - np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * pct[m, l-2, :]

    if norm == "schmidt":
        for l in range(0, nmax):
            if inverse:
                pct[:, l, :] = pct[:, l, :] * np.sqrt(2*l + 1)
            else:
                pct[:, l, :] = pct[:, l, :] / np.sqrt(2*l + 1)

    pct = pct[:mmax, :lmax]

    if csphase:
        for m in range(1, mmax, 2):
            pct[m] *= -1
......
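The recurrence in the hunk above builds the orthonormalized associated Legendre functions from the diagonal outward. A standalone numpy check of the first two steps (orthonormal normalization, before the Condon-Shortley phase is applied) against the known closed forms:

```python
import numpy as np

# Verify the base cases of the Legendre recurrence used in precompute_legpoly:
# starting from \bar P_0^0 = 1/sqrt(4*pi), the lower-diagonal and diagonal
# steps must reproduce the closed forms of \bar P_1^0 and \bar P_1^1.
t = np.linspace(0.1, np.pi - 0.1, 50)
cost, sint = np.cos(t), np.sin(t)

p00 = np.full_like(t, 1.0 / np.sqrt(4 * np.pi))        # \bar P_0^0
p10 = np.sqrt(3) * cost * p00                          # lower-diagonal step -> \bar P_1^0
p11 = np.sqrt(3 * (1 + cost) * (1 - cost) / 2) * p00   # diagonal step -> \bar P_1^1

assert np.allclose(p10, np.sqrt(3 / (4 * np.pi)) * cost)
assert np.allclose(p11, np.sqrt(3 / (8 * np.pi)) * sint)
```

The `(1 + cost) * (1 - cost)` factoring of `sin^2` matches the source and avoids catastrophic cancellation near the poles.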
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
from .sht import InverseRealSHT
class GaussianRandomFieldS2(torch.nn.Module):
    def __init__(self, nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid="equiangular", dtype=torch.float32):
        """A mean-zero Gaussian Random Field on the sphere with Matern covariance:
        C = sigma^2 (-Lap + tau^2 I)^(-alpha).

        Lap is the Laplacian on the sphere, I the identity operator,
        and sigma, tau, alpha are scalar parameters.

        Note: C is trace-class on L^2 if and only if alpha > 1.

        Parameters
        ----------
        nlat : int
            Number of latitudinal modes;
            longitudinal modes are 2*nlat.
        alpha : float, default is 2
            Regularity parameter. Larger means smoother.
        tau : float, default is 3
            Length-scale parameter. Larger means more scales.
        sigma : float, default is None
            Scale parameter. Larger means larger-amplitude samples.
            If None, sigma = tau**(0.5*(2*alpha - 2.0)).
        radius : float, default is 1
            Radius of the sphere.
        grid : string, default is "equiangular"
            Grid type. Currently supports "equiangular" and
            "legendre-gauss".
        dtype : torch.dtype, default is torch.float32
            Numerical type for the calculations.
        """
        super().__init__()

        # number of latitudinal modes
        self.nlat = nlat

        # default value of sigma if None is given
        if sigma is None:
            assert alpha > 1.0, f"Alpha must be greater than one, got {alpha}."
            sigma = tau**(0.5*(2*alpha - 2.0))
        # inverse SHT
        self.isht = InverseRealSHT(self.nlat, 2*self.nlat, grid=grid, norm='backward').to(dtype=dtype)

        # square root of the eigenvalues of C
        sqrt_eig = torch.tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat, 1).repeat(1, self.nlat+1)
        sqrt_eig = torch.tril(sigma*(((sqrt_eig/radius**2) + tau**2)**(-alpha/2.0)))
        sqrt_eig[0, 0] = 0.0
        sqrt_eig = sqrt_eig.unsqueeze(0)
        self.register_buffer('sqrt_eig', sqrt_eig)

        # save the mean and variance of the standard Gaussian;
        # these are needed to re-initialize the distribution on a new device
        mean = torch.tensor([0.0]).to(dtype=dtype)
        var = torch.tensor([1.0]).to(dtype=dtype)
        self.register_buffer('mean', mean)
        self.register_buffer('var', var)

        # standard normal noise sampler
        self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
    def forward(self, N, xi=None):
        """Sample random functions from a spherical GRF.

        Parameters
        ----------
        N : int
            Number of functions to sample.
        xi : torch.Tensor, default is None
            Complex noise tensor of size (N, nlat, nlat+1).
            If None, new Gaussian noise is sampled.
            If xi is provided, N is ignored.

        Returns
        -------
        u : torch.Tensor
            N random samples from the GRF, returned as a
            tensor of size (N, nlat, 2*nlat) on an equiangular grid.
        """
        # sample Gaussian noise
        if xi is None:
            xi = self.gaussian_noise.sample(torch.Size((N, self.nlat, self.nlat + 1, 2))).squeeze()
            xi = torch.view_as_complex(xi)

        # Karhunen-Loeve expansion
        u = self.isht(xi * self.sqrt_eig)

        return u
    # override the cuda and to methods so the sampler gets re-initialized
    # with the mean and variance on the correct device
    def cuda(self, *args, **kwargs):
        super().cuda(*args, **kwargs)
        self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
        return self

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
        return self
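The `sqrt_eig` buffer built in `__init__` above encodes the Matern spectrum: each degree-`l` coefficient is scaled by `sigma * (l(l+1)/r^2 + tau^2)^(-alpha/2)`. A pure-numpy sketch of that spectral decay, without the torch dependency:

```python
import numpy as np

# Numpy analogue of the sqrt_eig computation in GaussianRandomFieldS2.__init__:
# sqrt_eig[l] = sigma * (l(l+1)/radius^2 + tau^2)^(-alpha/2),
# with the l = 0 (constant) mode zeroed so samples have zero mean.
nlat, alpha, tau, radius = 16, 2.0, 3.0, 1.0
sigma = tau ** (0.5 * (2 * alpha - 2.0))  # the module's default sigma

l = np.arange(nlat)
sqrt_eig = sigma * ((l * (l + 1) / radius**2 + tau**2) ** (-alpha / 2.0))
sqrt_eig[0] = 0.0  # zero mean: drop the constant mode

# the coefficients decay monotonically in l; larger alpha means faster
# decay and hence smoother samples
assert np.all(np.diff(sqrt_eig[1:]) < 0)
```

Scaling i.i.d. Gaussian noise by these coefficients before the inverse SHT is exactly the Karhunen-Loeve expansion used in `forward`.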
@@ -36,7 +36,6 @@ import torch.fft
 from .quadrature import *
 from .legendre import *
-from .distributed import copy_to_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region

 class RealSHT(nn.Module):
@@ -94,10 +93,6 @@ class RealSHT(nn.Module):
         pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
         weights = torch.einsum('mlk,k->mlk', pct, weights)

-        # shard the weights along n, because we want to be distributed in spectral space:
-        weights = scatter_to_parallel_region(weights, dim=1)
-        self.lmax_local = weights.shape[1]
-
         # remember quadrature weights
         self.register_buffer('weights', weights, persistent=False)
@@ -119,9 +114,8 @@ class RealSHT(nn.Module):
         x = torch.view_as_real(x)

         # distributed contraction: fork
-        x = copy_to_parallel_region(x)
         out_shape = list(x.size())
-        out_shape[-3] = self.lmax_local
+        out_shape[-3] = self.lmax
         out_shape[-2] = self.mmax
         xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
@@ -174,10 +168,7 @@ class InverseRealSHT(nn.Module):
         pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

-        # shard the pct along the n dim
-        pct = scatter_to_parallel_region(pct, dim=1)
-        self.lmax_local = pct.shape[1]
+        # register buffer
         self.register_buffer('pct', pct, persistent=False)

     def extra_repr(self):
@@ -188,7 +179,7 @@ class InverseRealSHT(nn.Module):
     def forward(self, x: torch.Tensor):

-        assert(x.shape[-2] == self.lmax_local)
+        assert(x.shape[-2] == self.lmax)
         assert(x.shape[-1] == self.mmax)

         # Evaluate associated Legendre functions on the output nodes
@@ -198,9 +189,6 @@ class InverseRealSHT(nn.Module):
         im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct )
         xs = torch.stack((rl, im), -1)

-        # distributed contraction: join
-        xs = reduce_from_parallel_region(xs)
-
         # apply the inverse (real) FFT
         x = torch.view_as_complex(xs)
         x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
@@ -267,10 +255,6 @@ class RealVectorSHT(nn.Module):
         # since the second component is imaginary, we need to take complex conjugation into account
         weights[1] = -1 * weights[1]

-        # shard the weights along n, because we want to be distributed in spectral space:
-        weights = scatter_to_parallel_region(weights, dim=2)
-        self.lmax_local = weights.shape[2]
-
         # remember quadrature weights
         self.register_buffer('weights', weights, persistent=False)
@@ -291,9 +275,8 @@ class RealVectorSHT(nn.Module):
         x = torch.view_as_real(x)

         # distributed contraction: fork
-        x = copy_to_parallel_region(x)
         out_shape = list(x.size())
-        out_shape[-3] = self.lmax_local
+        out_shape[-3] = self.lmax
         out_shape[-2] = self.mmax
         xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
@@ -356,10 +339,7 @@ class InverseRealVectorSHT(nn.Module):
         dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)

-        # shard the pct along the n dim
-        dpct = scatter_to_parallel_region(dpct, dim=2)
-        self.lmax_local = dpct.shape[2]
+        # register weights
         self.register_buffer('dpct', dpct, persistent=False)

     def extra_repr(self):
@@ -370,7 +350,7 @@ class InverseRealVectorSHT(nn.Module):
     def forward(self, x: torch.Tensor):

-        assert(x.shape[-2] == self.lmax_local)
+        assert(x.shape[-2] == self.lmax)
         assert(x.shape[-1] == self.mmax)

         # Evaluate associated Legendre functions on the output nodes
@@ -397,9 +377,6 @@ class InverseRealVectorSHT(nn.Module):
         t = torch.stack((trl, tim), -1)
         xs = torch.stack((s, t), -4)

-        # distributed contraction: join
-        xs = reduce_from_parallel_region(xs)
-
         # apply the inverse (real) FFT
         x = torch.view_as_complex(xs)
         x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
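The `torch.einsum('mlk,k->mlk', pct, weights)` call retained in `RealSHT` above folds the quadrature weights into the precomputed Legendre tensor. A small numpy illustration of that shape bookkeeping, with stand-in values for the weights:

```python
import numpy as np

# Shape bookkeeping of the einsum in RealSHT: pct has shape (mmax, lmax, nlat)
# and the quadrature weights have shape (nlat,); 'mlk,k->mlk' scales each
# latitude column by its weight. The weight values here are stand-ins.
mmax, lmax, nlat = 4, 4, 8
pct = np.random.default_rng(0).standard_normal((mmax, lmax, nlat))
w = np.full(nlat, np.pi / nlat)  # hypothetical uniform quadrature weights

weighted = np.einsum('mlk,k->mlk', pct, w)
assert weighted.shape == (mmax, lmax, nlat)
assert np.allclose(weighted, pct * w)  # equivalent via broadcasting over k
```

Since the contraction only rescales the last axis, broadcasting gives the same result; the einsum form simply makes the index convention explicit.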