Commit ef4f6518 authored by Boris Bonev

v0.5 update

parent 77ac7836
......@@ -3,3 +3,4 @@ The code was authored by the following people:
Boris Bonev - NVIDIA Corporation
Christian Hundt - NVIDIA Corporation
Thorsten Kurth - NVIDIA Corporation
Nikola Kovachki - NVIDIA Corporation
......@@ -2,6 +2,11 @@
## Versioning
### v0.5
* Reworked distributed SHT
* Module for sampling Gaussian Random Fields on the sphere
### v0.4
* Computation of associated Legendre polynomials
......@@ -32,13 +37,4 @@
### v0.1
* Single GPU forward and backward transform
* Minimal code example and notebook
<!-- ## Detailed logs
### 23-11-2022
* Initialized the library
* Added `getting_started.ipynb` example
* Added simple example to test the SHT
* Logo -->
* Minimal code example and notebook
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
......@@ -36,7 +36,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
<!-- ## What is torch-harmonics? -->
`torch_harmonics` is a differentiable implementation of the Spherical Harmonic transform in PyTorch. It uses quadrature to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. This algorithm tends to outperform others with better asymptotic scaling for most practical purposes.
Spherical Harmonic Transforms (SHTs) are the counterpart to Fourier transforms on the sphere. As such, they are an invaluable tool for signal processing on the sphere.
`torch_harmonics` is a differentiable implementation of the SHT in PyTorch. It uses quadrature rules to compute the projection onto the associated Legendre polynomials and FFTs for the projection onto the harmonic basis. For most practical problem sizes, this algorithm tends to outperform alternatives with better asymptotic scaling.
`torch_harmonics` uses PyTorch primitives to implement these operations, making it fully differentiable. Moreover, the quadrature can be distributed across multiple ranks, yielding a spatially distributed transform.
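Because the transform is composed of differentiable PyTorch operations, gradients can be propagated through it like through any other layer. The following minimal sketch (not taken from the repository; sizes and loss are arbitrary choices for the example) illustrates this:

```python
import torch
import torch_harmonics as th

nlat, nlon = 64, 128
sht = th.RealSHT(nlat, nlon, grid="equiangular").float()

# random signal on the equiangular grid, tracked for autograd
signal = torch.randn(4, nlat, nlon, requires_grad=True)

coeffs = sht(signal)                # complex spherical harmonic coefficients
loss = coeffs.abs().pow(2).sum()    # any real-valued objective works
loss.backward()                     # gradients flow back to `signal`

print(signal.grad.shape)            # torch.Size([4, 64, 128])
```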
`torch_harmonics` has been used to implement a variety of differentiable PDE solvers which generated the animations below.
<table border="0" cellspacing="0" cellpadding="0">
......@@ -76,6 +82,7 @@ docker run --gpus all -it --rm --ipc=host --ulimit memlock=-1 --ulimit stack=671
- Boris Bonev (bbonev@nvidia.com)
- Christian Hundt (chundt@nvidia.com)
- Thorsten Kurth (tkurth@nvidia.com)
- Nikola Kovachki (nkovachki@nvidia.com)
## Implementation
The implementation follows the paper "Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations", N. Schaeffer, G3: Geochemistry, Geophysics, Geosystems.
......@@ -126,7 +133,7 @@ The main functionality of `torch_harmonics` is provided in the form of `torch.nn
```python
import torch
import torch_harmonics as harmonics
import torch_harmonics as th
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
......@@ -136,11 +143,13 @@ batch_size = 32
signal = torch.randn(batch_size, nlat, nlon)
# transform data on an equiangular grid
sht = harmonics.RealSHT(nlat, nlon, grid="equiangular").to(device).float()
sht = th.RealSHT(nlat, nlon, grid="equiangular").to(device).float()
coeffs = sht(signal)
```
`torch_harmonics` also implements a distributed variant of the SHT, located in `torch_harmonics.distributed`.
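A hedged sketch of setting it up, based on the distributed test shipped with this release; here `h_group` and `w_group` are assumed to be existing `torch.distributed` process groups partitioning latitude (polar) and longitude (azimuth), `device` is the local GPU, and `signal_local` is the latitude/longitude shard of the input held by the calling rank:

```python
import torch_harmonics.distributed as thd

# polar (latitude) group first, azimuth (longitude) group second
thd.init(h_group, w_group)

# global grid size; each rank only ever sees its local shard
sht = thd.DistributedRealSHT(nlat=721, nlon=1440).to(device)

# local spectral coefficients, split in l and m across the process grid
coeffs_local = sht(signal_local)
```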
## References
<a id="1">[1]</a>
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......@@ -333,7 +333,7 @@ class ShallowWaterSolver(nn.Module):
return out
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=True):
def plot_griddata(self, data, fig, cmap='twilight_shifted', vmax=None, vmin=None, projection='3d', title=None, antialiased=False):
"""
plotting routine for data on the grid. Requires cartopy for 3d plots.
"""
......
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The FourCastNet Authors. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
......@@ -33,7 +33,7 @@ from setuptools import setup
setup(
name='torch_harmonics',
version='0.4',
version='0.5',
author='Boris Bonev',
author_email='bbonev@nvidia.com',
packages=['torch_harmonics',],
......
......@@ -29,86 +29,57 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
import os
sys.path.append("..")
sys.path.append(".")
# we need this in order to enable distributed
import torch
import torch.distributed as dist
import torch_harmonics as harmonics
from torch_harmonics.distributed.primitives import gather_from_parallel_region, scatter_to_parallel_region
try:
from tqdm import tqdm
except:
tqdm = lambda x : x
# set up distributed
world_size = int(os.getenv('WORLD_SIZE', 1))
world_rank = int(os.getenv('WORLD_RANK', 0))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
dist.init_process_group(backend = 'nccl',
init_method = f"tcp://{master_address}:{port}",
rank = world_rank,
world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
mp_group = dist.new_group(ranks=list(range(world_size)))
my_rank = dist.get_rank(mp_group)
group_size = 1 if not dist.is_initialized() else dist.get_world_size(mp_group)
device = torch.device(f"cuda:{local_rank}")
# those need to be global
_POLAR_PARALLEL_GROUP = None
_AZIMUTH_PARALLEL_GROUP = None
_IS_INITIALIZED = False
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
def polar_group():
return _POLAR_PARALLEL_GROUP
if my_rank == 0:
print(f"Running distributed test on {group_size} ranks.")
def azimuth_group():
return _AZIMUTH_PARALLEL_GROUP
# common parameters
b, c, n_theta, n_lambda = 1, 21, 361, 720
def init(polar_process_group, azimuth_process_group):
global _POLAR_PARALLEL_GROUP
global _AZIMUTH_PARALLEL_GROUP
_POLAR_PARALLEL_GROUP = polar_process_group
_AZIMUTH_PARALLEL_GROUP = azimuth_process_group
_IS_INITIALIZED = True
# do serial tests first:
#forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
def is_initialized() -> bool:
return _IS_INITIALIZED
# set up signal
with torch.no_grad():
signal_leggauss = torch.randn(b, c, inverse_transform.lmax, inverse_transform.mmax, device=device, dtype=torch.complex128)
signal_leggauss_dist = signal_leggauss.clone()
signal_leggauss.requires_grad = True
def is_distributed_polar() -> bool:
return (_POLAR_PARALLEL_GROUP is not None)
# do a fwd and bwd pass:
x_local = inverse_transform(signal_leggauss)
loss = torch.sum(x_local)
loss.backward()
local_grad = torch.view_as_real(signal_leggauss.grad.clone())
def is_distributed_azimuth() -> bool:
return (_AZIMUTH_PARALLEL_GROUP is not None)
# now the distributed test
harmonics.distributed.init(mp_group)
inverse_transform_dist = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
with torch.no_grad():
signal_leggauss_dist = scatter_to_parallel_region(signal_leggauss_dist, dim=2)
signal_leggauss_dist.requires_grad = True
def polar_group_size() -> int:
if not is_distributed_polar():
return 1
else:
return dist.get_world_size(group = _POLAR_PARALLEL_GROUP)
# do distributed sht
x_dist = inverse_transform_dist(signal_leggauss_dist)
loss = torch.sum(x_dist)
loss.backward()
dist_grad = signal_leggauss_dist.grad.clone()
def azimuth_group_size() -> int:
if not is_distributed_azimuth():
return 1
else:
return dist.get_world_size(group = _AZIMUTH_PARALLEL_GROUP)
# gather the output
dist_grad = torch.view_as_real(gather_from_parallel_region(dist_grad, dim=2))
def polar_group_rank() -> int:
if not is_distributed_polar():
return 0
else:
return dist.get_rank(group = _POLAR_PARALLEL_GROUP)
if my_rank == 0:
print(f"Local Out: sum={x_local.abs().sum().item()}, max={x_local.max().item()}, min={x_local.min().item()}")
print(f"Dist Out: sum={x_dist.abs().sum().item()}, max={x_dist.max().item()}, min={x_dist.min().item()}")
diff = (x_local-x_dist).abs()
print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(x_local.abs().sum() + x_dist.abs().sum()))}, max={diff.max().item()}")
print("")
print(f"Local Grad: sum={local_grad.abs().sum().item()}, max={local_grad.max().item()}, min={local_grad.min().item()}")
print(f"Dist Grad: sum={dist_grad.abs().sum().item()}, max={dist_grad.max().item()}, min={dist_grad.min().item()}")
diff = (local_grad-dist_grad).abs()
print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(local_grad.abs().sum() + dist_grad.abs().sum()))}, max={diff.max().item()}")
def azimuth_group_rank() -> int:
if not is_distributed_azimuth():
return 0
else:
return dist.get_rank(group = _AZIMUTH_PARALLEL_GROUP)
......@@ -36,9 +36,10 @@ sys.path.append("..")
sys.path.append(".")
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
from torch_harmonics.distributed.primitives import gather_from_parallel_region
import torch_harmonics.distributed as thd
try:
from tqdm import tqdm
......@@ -46,70 +47,169 @@ except:
tqdm = lambda x : x
# set up distributed
world_size = int(os.getenv('WORLD_SIZE', 1))
world_rank = int(os.getenv('WORLD_RANK', 0))
grid_size_h = int(os.getenv('GRID_H', 1))
grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
world_size = grid_size_h * grid_size_w
dist.init_process_group(backend = 'nccl',
init_method = f"tcp://{master_address}:{port}",
rank = world_rank,
world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
mp_group = dist.new_group(ranks=list(range(world_size)))
my_rank = dist.get_rank(mp_group)
group_size = 1 if not dist.is_initialized() else dist.get_world_size(mp_group)
device = torch.device(f"cuda:{local_rank}")
# compute local ranks in h and w:
# rank = wrank + grid_size_w * hrank
wrank = world_rank % grid_size_w
hrank = world_rank // grid_size_w
w_group = None
h_group = None
# now set up the comm grid:
wgroups = []
for h in range(grid_size_h):
start = h * grid_size_w
end = start + grid_size_w
wgroups.append(list(range(start, end)))
print(wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if world_rank in grp:
w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
print(hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if world_rank in grp:
h_group = tmp_group
# set device
torch.cuda.set_device(device.index)
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
if my_rank == 0:
print(f"Running distributed test on {group_size} ranks.")
if world_rank == 0:
print(f"Running distributed test on grid H x W = {grid_size_h} x {grid_size_w}")
# initializing sht
thd.init(h_group, w_group)
# common parameters
b, c, n_theta, n_lambda = 1, 21, 361, 720
B, C, H, W = 1, 8, 721, 1440
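# per-rank shard sizes (ceil division) and the padding needed so the global
# H x W grid divides evenly across the process grid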
Hloc = (H + grid_size_h - 1) // grid_size_h
Wloc = (W + grid_size_w - 1) // grid_size_w
Hpad = grid_size_h * Hloc - H
Wpad = grid_size_w * Wloc - W
# do serial tests first:
forward_transform = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W).to(device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W).to(device)
Lloc = (forward_transform_dist.lpad + forward_transform_dist.lmax) // grid_size_h
Mloc = (forward_transform_dist.mpad + forward_transform_dist.mmax) // grid_size_w
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device)
# set up signal
# pad
with torch.no_grad():
signal_leggauss = inverse_transform(torch.randn(b, c, forward_transform.lmax, forward_transform.mmax, device=device, dtype=torch.complex128))
signal_leggauss_dist = signal_leggauss.clone()
signal_leggauss.requires_grad = True
signal_leggauss_dist.requires_grad = True
# do a fwd and bwd pass:
x_local = forward_transform(signal_leggauss)
loss = torch.sum(torch.view_as_real(x_local))
loss.backward()
x_local = torch.view_as_real(x_local)
local_grad = signal_leggauss.grad.clone()
# now the distributed test
harmonics.distributed.init(mp_group)
forward_transform_dist = harmonics.RealSHT(n_theta, n_lambda).to(device)
inverse_transform_dist = harmonics.InverseRealSHT(n_theta, n_lambda).to(device)
# do distributed sht
x_dist = forward_transform_dist(signal_leggauss_dist)
loss = torch.sum(torch.view_as_real(x_dist))
loss.backward()
x_dist = torch.view_as_real(x_dist)
dist_grad = signal_leggauss_dist.grad.clone()
# gather the output
x_dist = gather_from_parallel_region(x_dist, dim=2)
if my_rank == 0:
print(f"Local Out: sum={x_local.abs().sum().item()}, max={x_local.max().item()}, min={x_local.min().item()}")
print(f"Dist Out: sum={x_dist.abs().sum().item()}, max={x_dist.max().item()}, min={x_dist.min().item()}")
diff = (x_local-x_dist).abs()
print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(x_local.abs().sum() + x_dist.abs().sum()))}, max={diff.max().item()}")
inp_pad = F.pad(inp_full, (0, Wpad, 0, Hpad))
# split in W
inp_local = torch.split(inp_pad, split_size_or_sections=Wloc, dim=-1)[wrank]
# split in H
inp_local = torch.split(inp_local, split_size_or_sections=Hloc, dim=-2)[hrank]
# do FWD transform
out_full = forward_transform_local(inp_full)
out_local = forward_transform_dist(inp_local)
# gather the local data
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(out_local) for _ in range(grid_size_w)]
olist[wrank] = out_local
dist.all_gather(olist, out_local, group=w_group)
out_full_gather = torch.cat(olist, dim=-1)
out_full_gather = out_full_gather[..., :forward_transform_dist.mmax]
else:
out_full_gather = out_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(out_full_gather) for _ in range(grid_size_h)]
olist[hrank] = out_full_gather
dist.all_gather(olist, out_full_gather, group=h_group)
out_full_gather = torch.cat(olist, dim=-2)
out_full_gather = out_full_gather[..., :forward_transform_dist.lmax, :]
if world_rank == 0:
print(f"Local Out: sum={out_full.abs().sum().item()}, max={out_full.abs().max().item()}, min={out_full.abs().min().item()}")
print(f"Dist Out: sum={out_full_gather.abs().sum().item()}, max={out_full_gather.abs().max().item()}, min={out_full_gather.abs().min().item()}")
diff = (out_full-out_full_gather).abs()
print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(out_full.abs().sum() + out_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
print("")
print(f"Local Grad: sum={local_grad.abs().sum().item()}, max={local_grad.max().item()}, min={local_grad.min().item()}")
print(f"Dist Grad: sum={dist_grad.abs().sum().item()}, max={dist_grad.max().item()}, min={dist_grad.min().item()}")
diff = (local_grad-dist_grad).abs()
print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(local_grad.abs().sum() + dist_grad.abs().sum()))}, max={diff.max().item()}")
# create split input grad
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# pad
ograd_pad = F.pad(ograd_full, [0, forward_transform_dist.mpad, 0, forward_transform_dist.lpad])
# split in M
ograd_local = torch.split(ograd_pad, split_size_or_sections=Mloc, dim=-1)[wrank]
# split in H
ograd_local = torch.split(ograd_local, split_size_or_sections=Lloc, dim=-2)[hrank]
# backward pass:
# local
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
# distributed
inp_local.requires_grad = True
out_local = forward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
# gather
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(igrad_local) for _ in range(grid_size_w)]
olist[wrank] = igrad_local
dist.all_gather(olist, igrad_local, group=w_group)
igrad_full_gather = torch.cat(olist, dim=-1)
igrad_full_gather = igrad_full_gather[..., :W]
else:
igrad_full_gather = igrad_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(igrad_full_gather) for _ in range(grid_size_h)]
olist[hrank] = igrad_full_gather
dist.all_gather(olist, igrad_full_gather, group=h_group)
igrad_full_gather = torch.cat(olist, dim=-2)
igrad_full_gather = igrad_full_gather[..., :H, :]
if world_rank == 0:
print(f"Local Grad: sum={igrad_full.abs().sum().item()}, max={igrad_full.abs().max().item()}, min={igrad_full.abs().min().item()}")
print(f"Dist Grad: sum={igrad_full_gather.abs().sum().item()}, max={igrad_full_gather.abs().max().item()}, min={igrad_full_gather.abs().min().item()}")
diff = (igrad_full-igrad_full_gather).abs()
print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(igrad_full.abs().sum() + igrad_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
......@@ -31,3 +31,4 @@
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature
from . import random_fields
......@@ -30,5 +30,10 @@
#
# we need this in order to enable distributed
from .utils import init, is_initialized
from .primitives import copy_to_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
from .utils import init, is_initialized, polar_group, azimuth_group
from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank
from .primitives import distributed_transpose_azimuth, distributed_transpose_polar
# import the sht stuff
from .distributed_sht import DistributedRealSHT, DistributedInverseRealSHT
from .distributed_sht import DistributedRealVectorSHT, DistributedInverseRealVectorSHT
......@@ -32,7 +32,7 @@
import torch
import torch.distributed as dist
from .utils import get_model_parallel_group, is_initialized
from .utils import polar_group, azimuth_group, is_initialized
# general helpers
def get_memory_format(tensor):
......@@ -50,163 +50,54 @@ def split_tensor_along_dim(tensor, dim, num_chunks):
return tensor_list
# split
def _split(input_, dim_, group=None):
"""Split the tensor along its last dimension and keep the corresponding slice."""
def _transpose(tensor, dim0, dim1, group=None, async_op=False):
# get input format
input_format = get_memory_format(input_)
input_format = get_memory_format(tensor)
# Bypass the function if we are using only 1 GPU.
# get comm params
comm_size = dist.get_world_size(group=group)
if comm_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_dim(input_, dim_, comm_size)
# split and local transposition
split_size = tensor.shape[dim0] // comm_size
x_send = [y.contiguous(memory_format=input_format) for y in torch.split(tensor, split_size, dim=dim0)]
x_recv = [torch.empty_like(x_send[0]).contiguous(memory_format=input_format) for _ in range(comm_size)]
# Note: torch.split does not create contiguous tensors by default.
rank = dist.get_rank(group=group)
output = input_list[rank].contiguous(memory_format=input_format)
# global transposition
req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
return output
return x_recv, req
# those are used by the various helper functions
def _reduce(input_, use_fp32=True, group=None):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if dist.get_world_size(group=group) == 1:
return input_
# All-reduce.
if use_fp32:
dtype = input_.dtype
inputf_ = input_.float()
dist.all_reduce(inputf_, group=group)
input_ = inputf_.to(dtype)
else:
dist.all_reduce(input_, group=group)
return input_
class _CopyToParallelRegion(torch.autograd.Function):
"""Pass the input to the parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
class distributed_transpose_azimuth(torch.autograd.Function):
@staticmethod
def forward(ctx, input_):
return input_
def forward(ctx, x, dim):
xlist, _ = _transpose(x, dim[0], dim[1], group=azimuth_group())
x = torch.cat(xlist, dim=dim[1])
ctx.dim = dim
return x
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output, group=get_model_parallel_group())
# write a convenient functional wrapper
def copy_to_parallel_region(input_):
if not is_initialized():
return input_
else:
return _CopyToParallelRegion.apply(input_)
def backward(ctx, go):
dim = ctx.dim
gilist, _ = _transpose(go, dim[1], dim[0], group=azimuth_group())
gi = torch.cat(gilist, dim=dim[0])
return gi, None
# reduce
class _ReduceFromParallelRegion(torch.autograd.Function):
"""All-reduce the input from the parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_, group=get_model_parallel_group())
@staticmethod
def forward(ctx, input_):
return _reduce(input_, group=get_model_parallel_group())
class distributed_transpose_polar(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
return grad_output
def reduce_from_parallel_region(input_):
if not is_initialized():
return input_
else:
return _ReduceFromParallelRegion.apply(input_)
# gather
def _gather(input_, dim_, group=None):
"""Gather tensors and concatinate along the last dimension."""
# get input format
input_format = get_memory_format(input_)
print(input_format)
comm_size = dist.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if comm_size==1:
return input_
# sanity checks
assert(dim_ < input_.dim()), f"Error, cannot gather along {dim_} for tensor with {input_.dim()} dimensions."
# Size and dimension.
comm_rank = dist.get_rank(group=group)
def forward(ctx, x, dim):
xlist, _ = _transpose(x, dim[0], dim[1], group=polar_group())
x = torch.cat(xlist, dim=dim[1])
ctx.dim = dim
return x
# input needs to be contiguous
input_ = input_.contiguous(memory_format=input_format)
tensor_list = [torch.empty_like(input_) for _ in range(comm_size)]
tensor_list[comm_rank] = input_
dist.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format)
return output
class _GatherFromParallelRegion(torch.autograd.Function):
"""Gather the input from parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, dim_):
return _gather(input_, dim_, group=get_model_parallel_group())
@staticmethod
def forward(ctx, input_, dim_):
ctx.dim = dim_
return _gather(input_, dim_, group=get_model_parallel_group())
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, group=get_model_parallel_group()), None
def gather_from_parallel_region(input_, dim):
if not is_initialized():
return input_
else:
return _GatherFromParallelRegion.apply(input_, dim)
def backward(ctx, go):
dim = ctx.dim
gilist, _ = _transpose(go, dim[1], dim[0], group=polar_group())
gi = torch.cat(gilist, dim=dim[0])
return gi, None
# scatter
class _ScatterToParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_, dim_):
return _split(input_, dim_, group=get_model_parallel_group())
@staticmethod
def forward(ctx, input_, dim_):
ctx.dim = dim_
return _split(input_, dim_, group=get_model_parallel_group())
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output, ctx.dim, group=get_model_parallel_group()), None
def scatter_to_parallel_region(input_, dim):
if not is_initialized():
return input_
else:
return _ScatterToParallelRegion.apply(input_, dim)
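The copy/reduce/gather/scatter wrappers above are superseded by the two `distributed_transpose` primitives. A hedged sketch of how they might be invoked (an illustration, not the library's internal call pattern; it assumes a CUDA device, an azimuth group set up via `torch_harmonics.distributed.init`, and that the split dimension is divisible by the group size):

```python
import torch
from torch_harmonics.distributed import distributed_transpose_azimuth

# local shard: gathered along latitude (dim -2), split along longitude (dim -1)
x = torch.randn(1, 8, 720, 90, device="cuda")

# re-partition with an all-to-all: dim -2 becomes split across the azimuth
# group while dim -1 is gathered locally
x = distributed_transpose_azimuth.apply(x, (-2, -1))
```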
......@@ -34,15 +34,52 @@ import torch
import torch.distributed as dist
# those need to be global
_MODEL_PARALLEL_GROUP = None
_POLAR_PARALLEL_GROUP = None
_AZIMUTH_PARALLEL_GROUP = None
_IS_INITIALIZED = False
def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP
def polar_group():
return _POLAR_PARALLEL_GROUP
def init(process_group):
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = process_group
def azimuth_group():
return _AZIMUTH_PARALLEL_GROUP
def init(polar_process_group, azimuth_process_group):
global _POLAR_PARALLEL_GROUP
global _AZIMUTH_PARALLEL_GROUP
_POLAR_PARALLEL_GROUP = polar_process_group
_AZIMUTH_PARALLEL_GROUP = azimuth_process_group
_IS_INITIALIZED = True
def is_initialized() -> bool:
return _MODEL_PARALLEL_GROUP is not None
return _IS_INITIALIZED
def is_distributed_polar() -> bool:
return (_POLAR_PARALLEL_GROUP is not None)
def is_distributed_azimuth() -> bool:
return (_AZIMUTH_PARALLEL_GROUP is not None)
def polar_group_size() -> int:
if not is_distributed_polar():
return 1
else:
return dist.get_world_size(group = _POLAR_PARALLEL_GROUP)
def azimuth_group_size() -> int:
if not is_distributed_azimuth():
return 1
else:
return dist.get_world_size(group = _AZIMUTH_PARALLEL_GROUP)
def polar_group_rank() -> int:
if not is_distributed_polar():
return 0
else:
return dist.get_rank(group = _POLAR_PARALLEL_GROUP)
def azimuth_group_rank() -> int:
if not is_distributed_azimuth():
return 0
else:
return dist.get_rank(group = _AZIMUTH_PARALLEL_GROUP)
......@@ -53,7 +53,8 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True)
"""
# compute the tensor P^m_n:
pct = np.zeros((mmax, lmax, len(t)), dtype=np.float64)
nmax = max(mmax,lmax)
pct = np.zeros((nmax, nmax, len(t)), dtype=np.float64)
sint = np.sin(t)
cost = np.cos(t)
......@@ -65,23 +66,25 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True)
pct[0,0,:] = norm_factor / np.sqrt(4 * np.pi)
# fill the diagonal and the lower diagonal
for l in range(1, min(mmax,lmax)):
for l in range(1, nmax):
pct[l-1, l, :] = np.sqrt(2*l + 1) * cost * pct[l-1, l-1, :]
pct[l, l, :] = np.sqrt( (2*l + 1) * (1 + cost) * (1 - cost) / 2 / l ) * pct[l-1, l-1, :]
# fill the remaining values on the upper triangle and multiply b
for l in range(2, lmax):
for l in range(2, nmax):
for m in range(0, l-1):
pct[m, l, :] = cost * np.sqrt((2*l - 1) / (l - m) * (2*l + 1) / (l + m)) * pct[m, l-1, :] \
- np.sqrt((l + m - 1) / (l - m) * (2*l + 1) / (2*l - 3) * (l - m - 1) / (l + m)) * pct[m, l-2, :]
if norm == "schmidt":
for l in range(0, lmax):
for l in range(0, nmax):
if inverse:
pct[:, l, : ] = pct[:, l, : ] * np.sqrt(2*l + 1)
else:
pct[:, l, : ] = pct[:, l, : ] / np.sqrt(2*l + 1)
pct = pct[:mmax, :lmax]
if csphase:
for m in range(1, mmax, 2):
pct[m] *= -1
......
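Since the recurrence is now carried out on the full `nmax x nmax` triangle before slicing to `(mmax, lmax)`, the two truncation parameters no longer need to coincide. A small hedged sketch of calling the routine (the colatitude nodes here are arbitrary; in the library they come from the quadrature rules, and the import path assumes the function lives in `torch_harmonics.legendre`, as the `from .legendre import *` in `sht.py` suggests):

```python
import numpy as np
from torch_harmonics.legendre import precompute_legpoly

lmax, mmax = 64, 48
t = np.linspace(0.0, np.pi, 65)           # colatitude nodes in [0, pi]

pct = precompute_legpoly(mmax, lmax, t)   # associated Legendre values, shape (mmax, lmax, len(t))
print(pct.shape)
```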
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import torch
from .sht import InverseRealSHT
class GaussianRandomFieldS2(torch.nn.Module):
def __init__(self, nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid="equiangular", dtype=torch.float32):
super().__init__()
"""A mean-zero Gaussian Random Field on the sphere with Matern covariance:
C = sigma^2 (-Lap + tau^2 I)^(-alpha).
Lap is the Laplacian on the sphere, I the identity operator,
and sigma, tau, alpha are scalar parameters.
Note: C is trace-class on L^2 if and only if alpha > 1.
Parameters
----------
nlat : int
Number of latitudinal modes;
longitudinal modes are 2*nlat.
alpha : float, default is 2
Regularity parameter. Larger means smoother.
tau : float, default is 3
Length-scale parameter. Larger means more scales.
sigma : float, default is None
Scale parameter. Larger means larger amplitude.
If None, sigma = tau**(0.5*(2*alpha - 2.0)).
radius : float, default is 1
Radius of the sphere.
grid : string, default is "equiangular"
Grid type. Currently supports "equiangular" and
"legendre-gauss".
dtype : torch.dtype, default is torch.float32
Numerical type for the calculations.
"""
#Number of latitudinal modes.
self.nlat = nlat
#Default value of sigma if None is given.
if sigma is None:
assert alpha > 1.0, f"Alpha must be greater than one, got {alpha}."
sigma = tau**(0.5*(2*alpha - 2.0))
# Inverse SHT
self.isht = InverseRealSHT(self.nlat, 2*self.nlat, grid=grid, norm='backward').to(dtype=dtype)
#Square root of the eigenvalues of C.
sqrt_eig = torch.tensor([j*(j+1) for j in range(self.nlat)]).view(self.nlat,1).repeat(1, self.nlat+1)
sqrt_eig = torch.tril(sigma*(((sqrt_eig/radius**2) + tau**2)**(-alpha/2.0)))
sqrt_eig[0,0] = 0.0
sqrt_eig = sqrt_eig.unsqueeze(0)
self.register_buffer('sqrt_eig', sqrt_eig)
#Save mean and var of the standard Gaussian.
#Need these to re-initialize distribution on a new device.
mean = torch.tensor([0.0]).to(dtype=dtype)
var = torch.tensor([1.0]).to(dtype=dtype)
self.register_buffer('mean', mean)
self.register_buffer('var', var)
#Standard normal noise sampler.
self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
def forward(self, N, xi=None):
"""Sample random functions from a spherical GRF.
Parameters
----------
N : int
Number of functions to sample.
xi : torch.Tensor, default is None
Noise is a complex tensor of size (N, nlat, nlat+1).
If None, new Gaussian noise is sampled.
If xi is provided, N is ignored.
Output
-------
u : torch.Tensor
N random samples from the GRF returned as a
tensor of size (N, nlat, 2*nlat) on an equiangular grid.
"""
#Sample Gaussian noise.
if xi is None:
xi = self.gaussian_noise.sample(torch.Size((N, self.nlat, self.nlat + 1, 2))).squeeze()
xi = torch.view_as_complex(xi)
#Karhunen-Loeve expansion.
u = self.isht(xi*self.sqrt_eig)
return u
#Override cuda and to methods so sampler gets initialized with mean
#and variance on the correct device.
def cuda(self, *args, **kwargs):
super().cuda(*args, **kwargs)
self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
return self
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
return self
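A brief usage sketch for the new random-field module (hedged: the parameter values are arbitrary, and the class is imported from the `random_fields` submodule registered in `torch_harmonics/__init__.py`):

```python
from torch_harmonics.random_fields import GaussianRandomFieldS2

grf = GaussianRandomFieldS2(nlat=128, alpha=2.0, tau=3.0)

u = grf(4)          # four independent samples of the GRF
print(u.shape)      # torch.Size([4, 128, 256]) on the equiangular grid
```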
......@@ -36,7 +36,6 @@ import torch.fft
from .quadrature import *
from .legendre import *
from .distributed import copy_to_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
class RealSHT(nn.Module):
......@@ -94,10 +93,6 @@ class RealSHT(nn.Module):
pct = precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
weights = torch.einsum('mlk,k->mlk', pct, weights)
# shard the weights along n, because we want to be distributed in spectral space:
weights = scatter_to_parallel_region(weights, dim=1)
self.lmax_local = weights.shape[1]
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
......@@ -119,9 +114,8 @@ class RealSHT(nn.Module):
x = torch.view_as_real(x)
# distributed contraction: fork
x = copy_to_parallel_region(x)
out_shape = list(x.size())
out_shape[-3] = self.lmax_local
out_shape[-3] = self.lmax
out_shape[-2] = self.mmax
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
......@@ -174,10 +168,7 @@ class InverseRealSHT(nn.Module):
pct = precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# shard the pct along the n dim
pct = scatter_to_parallel_region(pct, dim=1)
self.lmax_local = pct.shape[1]
# register buffer
self.register_buffer('pct', pct, persistent=False)
def extra_repr(self):
......@@ -188,7 +179,7 @@ class InverseRealSHT(nn.Module):
def forward(self, x: torch.Tensor):
assert(x.shape[-2] == self.lmax_local)
assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax)
# Evaluate associated Legendre functions on the output nodes
......@@ -198,9 +189,6 @@ class InverseRealSHT(nn.Module):
im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct )
xs = torch.stack((rl, im), -1)
# distributed contraction: join
xs = reduce_from_parallel_region(xs)
# apply the inverse (real) FFT
x = torch.view_as_complex(xs)
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......@@ -267,10 +255,6 @@ class RealVectorSHT(nn.Module):
# since the second component is imaginary, we need to take complex conjugation into account
weights[1] = -1 * weights[1]
# shard the weights along n, because we want to be distributed in spectral space:
weights = scatter_to_parallel_region(weights, dim=2)
self.lmax_local = weights.shape[2]
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
......@@ -291,9 +275,8 @@ class RealVectorSHT(nn.Module):
x = torch.view_as_real(x)
# distributed contraction: fork
x = copy_to_parallel_region(x)
out_shape = list(x.size())
out_shape[-3] = self.lmax_local
out_shape[-3] = self.lmax
out_shape[-2] = self.mmax
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
......@@ -356,10 +339,7 @@ class InverseRealVectorSHT(nn.Module):
dpct = precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
# shard the pct along the n dim
dpct = scatter_to_parallel_region(dpct, dim=2)
self.lmax_local = dpct.shape[2]
# register weights
self.register_buffer('dpct', dpct, persistent=False)
def extra_repr(self):
......@@ -370,7 +350,7 @@ class InverseRealVectorSHT(nn.Module):
def forward(self, x: torch.Tensor):
assert(x.shape[-2] == self.lmax_local)
assert(x.shape[-2] == self.lmax)
assert(x.shape[-1] == self.mmax)
# Evaluate associated Legendre functions on the output nodes
......@@ -397,9 +377,6 @@ class InverseRealVectorSHT(nn.Module):
t = torch.stack((trl, tim), -1)
xs = torch.stack((s, t), -4)
# distributed contraction: join
xs = reduce_from_parallel_region(xs)
# apply the inverse (real) FFT
x = torch.view_as_complex(xs)
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
......