Unverified Commit eab72f04 authored by Thorsten Kurth, committed by GitHub

Tkurth/flexible sharding (#22)

* working distributed fwd SHT with flexible sharding and without padding

* fixing distributed sht with new logic

* fixed distributed SHT and ISHT with flexible padding

* working unittest for distributed fwd SHT

* distributed tests converted to unit tests

* bumping version number up to 0.6.4

* fixing small splitting bug in the distributed transforms

* updated changelog, removed tests folder
parent 31a33579
......@@ -2,6 +2,11 @@
## Versioning
### v0.6.4
* reworking distributed to allow for unevenly split tensors, effectively removing the need to pad the transformed tensors (see the sketch below)
* distributed SHT tests now use unittest. Tests extended to the vector SHT versions. Tests are defined in `torch_harmonics/distributed/distributed_tests.py`
* base pytorch container version bumped up to 23.11 in Dockerfile
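
A minimal sketch of the new splitting behaviour, using the `compute_split_shapes` helper that this release exports from `torch_harmonics.distributed`; the concrete sizes below are illustrative only:

```python
# illustrative only: balanced, padding-free splits of an uneven dimension
from torch_harmonics.distributed import compute_split_shapes

print(compute_split_shapes(721, 4))  # [181, 181, 181, 178] -- no padded rows required
print(compute_split_shapes(361, 2))  # [181, 180]
```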
### v0.6.3
* Adding gradient check in unit tests
......@@ -62,4 +67,4 @@
### v0.1.0
* Single GPU forward and backward transform
* Minimal code example and notebook
\ No newline at end of file
* Minimal code example and notebook
......@@ -30,9 +30,10 @@
# build after cloning into the torch_harmonics directory via
# docker build . -t torch_harmonics
FROM nvcr.io/nvidia/pytorch:22.08-py3
FROM nvcr.io/nvidia/pytorch:23.11-py3
COPY . /workspace/torch_harmonics
RUN pip install --use-feature=in-tree-build /workspace/torch_harmonics
RUN pip install parameterized
RUN pip install /workspace/torch_harmonics
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
import os
sys.path.append("..")
sys.path.append(".")
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd
try:
from tqdm import tqdm
except ImportError:
tqdm = lambda x : x
# set up distributed
world_rank = int(os.getenv('WORLD_RANK', 0))
grid_size_h = int(os.getenv('GRID_H', 1))
grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
world_size = grid_size_h * grid_size_w
dist.init_process_group(backend = 'nccl',
init_method = f"tcp://{master_address}:{port}",
rank = world_rank,
world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")
# compute local ranks in h and w:
# rank = wrank + grid_size_w * hrank
wrank = world_rank % grid_size_w
hrank = world_rank // grid_size_w
w_group = None
h_group = None
# now set up the comm grid:
wgroups = []
for h in range(grid_size_h):
start = h * grid_size_w
end = start + grid_size_w
wgroups.append(list(range(start, end)))
print(wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if wrank in grp:
w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
print(hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if hrank in grp:
h_group = tmp_group
# set device
torch.cuda.set_device(device.index)
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
if world_rank == 0:
print(f"Running distributed test on grid H x W = {grid_size_h} x {grid_size_w}")
# initializing sht
thd.init(h_group, w_group)
# common parameters
B, C, H, W = 1, 8, 721, 1440
Hloc = (H + grid_size_h - 1) // grid_size_h
Wloc = (W + grid_size_w - 1) // grid_size_w
Hpad = grid_size_h * Hloc - H
Wpad = grid_size_w * Wloc - W
# do serial tests first:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W).to(device)
backward_transform_local = harmonics.InverseRealSHT(nlat=H, nlon=W).to(device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W).to(device)
Lpad = backward_transform_dist.lpad
Mpad = backward_transform_dist.mpad
Lloc = (Lpad + backward_transform_dist.lmax) // grid_size_h
Mloc = (Mpad + backward_transform_dist.mmax) // grid_size_w
# create tensors
dummy_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device)
inp_full = forward_transform_local(dummy_full)
# pad
with torch.no_grad():
inp_pad = F.pad(inp_full, (0, Mpad, 0, Lpad))
# split in W
inp_local = torch.split(inp_pad, split_size_or_sections=Mloc, dim=-1)[wrank]
# split in H
inp_local = torch.split(inp_local, split_size_or_sections=Lloc, dim=-2)[hrank]
# do the inverse transform
out_full = backward_transform_local(inp_full)
out_local = backward_transform_dist(inp_local)
# gather the local data
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(out_local) for _ in range(grid_size_w)]
olist[wrank] = out_local
dist.all_gather(olist, out_local, group=w_group)
out_full_gather = torch.cat(olist, dim=-1)
out_full_gather = out_full_gather[..., :W]
else:
out_full_gather = out_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(out_full_gather) for _ in range(grid_size_h)]
olist[hrank] = out_full_gather
dist.all_gather(olist, out_full_gather, group=h_group)
out_full_gather = torch.cat(olist, dim=-2)
out_full_gather = out_full_gather[..., :H, :]
if world_rank == 0:
print(f"Local Out: sum={out_full.abs().sum().item()}, max={out_full.abs().max().item()}, min={out_full.abs().min().item()}")
print(f"Dist Out: sum={out_full_gather.abs().sum().item()}, max={out_full_gather.abs().max().item()}, min={out_full_gather.abs().min().item()}")
diff = (out_full-out_full_gather).abs()
print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(out_full.abs().sum() + out_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
print("")
# create split input grad
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# pad
ograd_pad = F.pad(ograd_full, [0, Wpad, 0, Hpad])
# split in W
ograd_local = torch.split(ograd_pad, split_size_or_sections=Wloc, dim=-1)[wrank]
# split in H
ograd_local = torch.split(ograd_local, split_size_or_sections=Hloc, dim=-2)[hrank]
# backward pass:
# local
inp_full.requires_grad = True
out_full = backward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
# distributed
inp_local.requires_grad = True
out_local = backward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
# gather
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(igrad_local) for _ in range(grid_size_w)]
olist[wrank] = igrad_local
dist.all_gather(olist, igrad_local, group=w_group)
igrad_full_gather = torch.cat(olist, dim=-1)
igrad_full_gather = igrad_full_gather[..., :backward_transform_dist.mmax]
else:
igrad_full_gather = igrad_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(igrad_full_gather) for _ in range(grid_size_h)]
olist[hrank] = igrad_full_gather
dist.all_gather(olist, igrad_full_gather, group=h_group)
igrad_full_gather = torch.cat(olist, dim=-2)
igrad_full_gather = igrad_full_gather[..., :backward_transform_dist.lmax, :]
if world_rank == 0:
print(f"Local Grad: sum={igrad_full.abs().sum().item()}, max={igrad_full.abs().max().item()}, min={igrad_full.abs().min().item()}")
print(f"Dist Grad: sum={igrad_full_gather.abs().sum().item()}, max={igrad_full_gather.abs().max().item()}, min={igrad_full_gather.abs().min().item()}")
diff = (igrad_full-igrad_full_gather).abs()
print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(igrad_full.abs().sum() + igrad_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
import os
sys.path.append("..")
sys.path.append(".")
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd
try:
from tqdm import tqdm
except ImportError:
tqdm = lambda x : x
# set up distributed
world_rank = int(os.getenv('WORLD_RANK', 0))
grid_size_h = int(os.getenv('GRID_H', 1))
grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
world_size = grid_size_h * grid_size_w
dist.init_process_group(backend = 'nccl',
init_method = f"tcp://{master_address}:{port}",
rank = world_rank,
world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")
# compute local ranks in h and w:
# rank = wrank + grid_size_w * hrank
wrank = world_rank % grid_size_w
hrank = world_rank // grid_size_w
w_group = None
h_group = None
# now set up the comm grid:
wgroups = []
for h in range(grid_size_h):
start = h * grid_size_w
end = start + grid_size_w
wgroups.append(list(range(start, end)))
print(wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if wrank in grp:
w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
print(hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if hrank in grp:
h_group = tmp_group
# set device
torch.cuda.set_device(device.index)
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
if world_rank == 0:
print(f"Running distributed test on grid H x W = {grid_size_h} x {grid_size_w}")
# initializing sht
thd.init(h_group, w_group)
# common parameters
B, C, H, W = 1, 8, 721, 1440
Hloc = (H + grid_size_h - 1) // grid_size_h
Wloc = (W + grid_size_w - 1) // grid_size_w
Hpad = grid_size_h * Hloc - H
Wpad = grid_size_w * Wloc - W
# do serial tests first:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W).to(device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W).to(device)
Lloc = (forward_transform_dist.lpad + forward_transform_dist.lmax) // grid_size_h
Mloc = (forward_transform_dist.mpad + forward_transform_dist.mmax) // grid_size_w
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device)
# pad
with torch.no_grad():
inp_pad = F.pad(inp_full, (0, Wpad, 0, Hpad))
# split in W
inp_local = torch.split(inp_pad, split_size_or_sections=Wloc, dim=-1)[wrank]
# split in H
inp_local = torch.split(inp_local, split_size_or_sections=Hloc, dim=-2)[hrank]
# do FWD transform
out_full = forward_transform_local(inp_full)
out_local = forward_transform_dist(inp_local)
# gather the local data
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(out_local) for _ in range(grid_size_w)]
olist[wrank] = out_local
dist.all_gather(olist, out_local, group=w_group)
out_full_gather = torch.cat(olist, dim=-1)
out_full_gather = out_full_gather[..., :forward_transform_dist.mmax]
else:
out_full_gather = out_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(out_full_gather) for _ in range(grid_size_h)]
olist[hrank] = out_full_gather
dist.all_gather(olist, out_full_gather, group=h_group)
out_full_gather = torch.cat(olist, dim=-2)
out_full_gather = out_full_gather[..., :forward_transform_dist.lmax, :]
if world_rank == 0:
print(f"Local Out: sum={out_full.abs().sum().item()}, max={out_full.abs().max().item()}, min={out_full.abs().min().item()}")
print(f"Dist Out: sum={out_full_gather.abs().sum().item()}, max={out_full_gather.abs().max().item()}, min={out_full_gather.abs().min().item()}")
diff = (out_full-out_full_gather).abs()
print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(out_full.abs().sum() + out_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
print("")
# create split input grad
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# pad
ograd_pad = F.pad(ograd_full, [0, forward_transform_dist.mpad, 0, forward_transform_dist.lpad])
# split in M
ograd_local = torch.split(ograd_pad, split_size_or_sections=Mloc, dim=-1)[wrank]
# split in H
ograd_local = torch.split(ograd_local, split_size_or_sections=Lloc, dim=-2)[hrank]
# backward pass:
# local
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
# distributed
inp_local.requires_grad = True
out_local = forward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
# gather
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(igrad_local) for _ in range(grid_size_w)]
olist[wrank] = igrad_local
dist.all_gather(olist, igrad_local, group=w_group)
igrad_full_gather = torch.cat(olist, dim=-1)
igrad_full_gather = igrad_full_gather[..., :W]
else:
igrad_full_gather = igrad_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(igrad_full_gather) for _ in range(grid_size_h)]
olist[hrank] = igrad_full_gather
dist.all_gather(olist, igrad_full_gather, group=h_group)
igrad_full_gather = torch.cat(olist, dim=-2)
igrad_full_gather = igrad_full_gather[..., :H, :]
if world_rank == 0:
print(f"Local Grad: sum={igrad_full.abs().sum().item()}, max={igrad_full.abs().max().item()}, min={igrad_full.abs().min().item()}")
print(f"Dist Grad: sum={igrad_full_gather.abs().sum().item()}, max={igrad_full_gather.abs().max().item()}, min={igrad_full_gather.abs().min().item()}")
diff = (igrad_full-igrad_full_gather).abs()
print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(igrad_full.abs().sum() + igrad_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
......@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__ = '0.6.3'
__version__ = '0.6.4'
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature
......
......@@ -32,7 +32,7 @@
# we need this in order to enable distributed
from .utils import init, is_initialized, polar_group, azimuth_group
from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank
from .primitives import distributed_transpose_azimuth, distributed_transpose_polar
from .primitives import compute_split_shapes, split_tensor_along_dim, distributed_transpose_azimuth, distributed_transpose_polar
# import the sht stuff
from .distributed_sht import DistributedRealSHT, DistributedInverseRealSHT
......
......@@ -40,6 +40,7 @@ from torch_harmonics.quadrature import legendre_gauss_weights, lobatto_weights,
from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim
class DistributedRealSHT(nn.Module):
......@@ -98,40 +99,20 @@ class DistributedRealSHT(nn.Module):
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# spatial paddings
latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
self.nlatpad = latdist * self.comm_size_polar - self.nlat
londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.nlonpad = londist * self.comm_size_azimuth - self.nlon
# frequency paddings
ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
self.lpad = ldist * self.comm_size_polar - self.lmax
mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute splits
self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)
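# e.g. (hypothetical sizes): nlat=721 over 4 polar ranks gives lat_shapes = [181, 181, 181, 178]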
# combine quadrature weights with the legendre weights
weights = torch.from_numpy(w)
pct = _precompute_legpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
pct = torch.from_numpy(pct)
weights = torch.einsum('mlk,k->mlk', pct, weights)
# we need to split in m, pad before:
weights = F.pad(weights, [0, 0, 0, 0, 0, self.mpad], mode="constant")
weights = torch.split(weights, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]
# compute the local pad and size
# spatial
self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
self.nlatpad_local = latdist - self.nlat_local
self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
self.nlonpad_local = londist - self.nlon_local
# frequency
self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
self.lpad_local = ldist - self.lmax_local
self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
self.mpad_local = mdist - self.mmax_local
# split weights
weights = split_tensor_along_dim(weights, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
......@@ -145,56 +126,42 @@ class DistributedRealSHT(nn.Module):
def forward(self, x: torch.Tensor):
# we need to ensure that we can split the channels evenly
assert(x.shape[1] % self.comm_size_polar == 0)
assert(x.shape[1] % self.comm_size_azimuth == 0)
num_chans = x.shape[1]
# h and w are split. First we make w local by transposing it into the channel dim
if self.comm_size_azimuth > 1:
xt = distributed_transpose_azimuth.apply(x, (1, -1))
else:
xt = x
x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_shapes)
# apply real fft in the longitudinal direction: make sure to truncate to nlon
xtf = 2.0 * torch.pi * torch.fft.rfft(xt, n=self.nlon, dim=-1, norm="forward")
x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")
# truncate
xtft = xtf[..., :self.mmax]
# pad the dim to allow for splitting
xtfp = F.pad(xtft, [0, self.mpad], mode="constant")
x = x[..., :self.mmax]
# transpose: after this, m is split and c is local
if self.comm_size_azimuth > 1:
y = distributed_transpose_azimuth.apply(xtfp, (-1, 1))
else:
y = xtfp
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
# transpose: after this, c is split and h is local
if self.comm_size_polar > 1:
yt = distributed_transpose_polar.apply(y, (1, -2))
else:
yt = y
# the input data might be padded, make sure to truncate to nlat:
ytt = yt[..., :self.nlat, :]
x = distributed_transpose_polar.apply(x, (1, -2), self.lat_shapes)
# do the Legendre-Gauss quadrature
yttr = torch.view_as_real(ytt)
x = torch.view_as_real(x)
# contraction
yor = torch.einsum('...kmr,mlk->...lmr', yttr, self.weights.to(yttr.dtype)).contiguous()
xs = torch.einsum('...kmr,mlk->...lmr', x, self.weights.to(x.dtype)).contiguous()
# pad if required, truncation is implicit
yopr = F.pad(yor, [0, 0, 0, 0, 0, self.lpad], mode="constant")
yop = torch.view_as_complex(yopr)
# cast to complex
x = torch.view_as_complex(xs)
# transpose: after this, l is split and c is local
if self.comm_size_polar > 1:
y = distributed_transpose_polar.apply(yop, (-2, 1))
else:
y = yop
return y
chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)
return x
class DistributedInverseRealSHT(nn.Module):
......@@ -243,38 +210,18 @@ class DistributedInverseRealSHT(nn.Module):
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# spatial paddings
latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
self.nlatpad = latdist * self.comm_size_polar - self.nlat
londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.nlonpad = londist * self.comm_size_azimuth - self.nlon
# frequency paddings
ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
self.lpad = ldist * self.comm_size_polar - self.lmax
mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute splits
self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)
# compute legendre polynomials
pct = _precompute_legpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
pct = torch.from_numpy(pct)
# split in m
pct = F.pad(pct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
pct = torch.split(pct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]
# compute the local pads and sizes
# spatial
self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
self.nlatpad_local = latdist - self.nlat_local
self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
self.nlonpad_local = londist - self.nlon_local
# frequency
self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
self.lpad_local = ldist - self.lmax_local
self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
self.mpad_local = mdist - self.mmax_local
pct = split_tensor_along_dim(pct, dim=0, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
# register
self.register_buffer('pct', pct, persistent=False)
......@@ -288,55 +235,41 @@ class DistributedInverseRealSHT(nn.Module):
def forward(self, x: torch.Tensor):
# we need to ensure that we can split the channels evenly
assert(x.shape[1] % self.comm_size_polar == 0)
assert(x.shape[1] % self.comm_size_azimuth == 0)
num_chans = x.shape[1]
# transpose: after that, channels are split, l is local:
if self.comm_size_polar > 1:
xt = distributed_transpose_polar.apply(x, (1, -2))
else:
xt = x
# remove padding in l:
xtt = xt[..., :self.lmax, :]
x = distributed_transpose_polar.apply(x, (1, -2), self.l_shapes)
# Evaluate associated Legendre functions on the output nodes
xttr = torch.view_as_real(xtt)
x = torch.view_as_real(x)
# einsum
xs = torch.einsum('...lmr, mlk->...kmr', xttr, self.pct.to(xttr.dtype)).contiguous()
x = torch.view_as_complex(xs)
xs = torch.einsum('...lmr, mlk->...kmr', x, self.pct.to(x.dtype)).contiguous()
#rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
#im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
#xs = torch.stack((rl, im), -1).contiguous()
# transpose: after this, nlat is split and channels are local
xp = F.pad(x, [0, 0, 0, self.nlatpad])
# inverse FFT
x = torch.view_as_complex(xs)
if self.comm_size_polar > 1:
y = distributed_transpose_polar.apply(xp, (-2, 1))
else:
y = xp
chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)
# transpose: after this, channels are split and m is local
if self.comm_size_azimuth > 1:
yt = distributed_transpose_azimuth.apply(y, (1, -1))
else:
yt = y
# truncate
ytt = yt[..., :self.mmax]
x = distributed_transpose_azimuth.apply(x, (1, -1), self.m_shapes)
# apply the inverse (real) FFT
x = torch.fft.irfft(ytt, n=self.nlon, dim=-1, norm="forward")
# pad before we transpose back
xp = F.pad(x, [0, self.nlonpad])
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
# transpose: after this, nlon is split and channels are local
if self.comm_size_azimuth > 1:
out = distributed_transpose_azimuth.apply(xp, (-1, 1))
else:
out = xp
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
return out
return x
class DistributedRealVectorSHT(nn.Module):
......@@ -393,18 +326,13 @@ class DistributedRealVectorSHT(nn.Module):
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# spatial paddings
latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
self.nlatpad = latdist * self.comm_size_polar - self.nlat
londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.nlonpad = londist * self.comm_size_azimuth - self.nlon
# frequency paddings
ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
self.lpad = ldist * self.comm_size_polar - self.lmax
mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute splits
self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)
# compute weights
weights = torch.from_numpy(w)
dpct = _precompute_dlegpoly(self.mmax, self.lmax, tq, norm=self.norm, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
......@@ -418,25 +346,12 @@ class DistributedRealVectorSHT(nn.Module):
weights[1] = -1 * weights[1]
# we need to split in m, pad before:
weights = F.pad(weights, [0, 0, 0, 0, 0, self.mpad], mode="constant")
weights = torch.split(weights, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=1)[self.comm_rank_azimuth]
weights = split_tensor_along_dim(weights, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
# remember quadrature weights
self.register_buffer('weights', weights, persistent=False)
# compute the local pad and size
# spatial
self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
self.nlatpad_local = latdist - self.nlat_local
self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
self.nlonpad_local = londist - self.nlon_local
# frequency
self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
self.lpad_local = ldist - self.lmax_local
self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
self.mpad_local = mdist - self.mmax_local
def extra_repr(self):
"""
Pretty print module
......@@ -446,72 +361,60 @@ class DistributedRealVectorSHT(nn.Module):
def forward(self, x: torch.Tensor):
assert(len(x.shape) >= 3)
assert(x.shape[1] % self.comm_size_polar == 0)
assert(x.shape[1] % self.comm_size_azimuth == 0)
# we need to ensure that we can split the channels evenly
num_chans = x.shape[1]
# h and w are split. First we make w local by transposing it into the channel dim
if self.comm_size_azimuth > 1:
xt = distributed_transpose_azimuth.apply(x, (1, -1))
else:
xt = x
x = distributed_transpose_azimuth.apply(x, (1, -1), self.lon_shapes)
# apply real fft in the longitudinal direction: make sure to truncate to nlon
xtf = 2.0 * torch.pi * torch.fft.rfft(xt, n=self.nlon, dim=-1, norm="forward")
x = 2.0 * torch.pi * torch.fft.rfft(x, n=self.nlon, dim=-1, norm="forward")
# truncate
xtft = xtf[..., :self.mmax]
# pad the dim to allow for splitting
xtfp = F.pad(xtft, [0, self.mpad], mode="constant")
x = x[..., :self.mmax]
# transpose: after this, m is split and c is local
if self.comm_size_azimuth > 1:
y = distributed_transpose_azimuth.apply(xtfp, (-1, 1))
else:
y = xtfp
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
# transpose: after this, c is split and h is local
if self.comm_size_polar > 1:
yt = distributed_transpose_polar.apply(y, (1, -2))
else:
yt = y
# the input data might be padded, make sure to truncate to nlat:
ytt = yt[..., :self.nlat, :]
x = distributed_transpose_polar.apply(x, (1, -2), self.lat_shapes)
# do the Legendre-Gauss quadrature
yttr = torch.view_as_real(ytt)
x = torch.view_as_real(x)
# create output array
yor = torch.zeros_like(yttr, dtype=yttr.dtype, device=yttr.device)
xs = torch.zeros_like(x, dtype=x.dtype, device=x.device)
# contraction - spheroidal component
# real component
yor[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 0], self.weights[0].to(yttr.dtype)) \
- torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 1], self.weights[1].to(yttr.dtype))
# imag component
yor[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 1], self.weights[0].to(yttr.dtype)) \
+ torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 0], self.weights[1].to(yttr.dtype))
xs[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], self.weights[0].to(xs.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], self.weights[1].to(xs.dtype))
# imag component
xs[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], self.weights[0].to(xs.dtype)) \
+ torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], self.weights[1].to(xs.dtype))
# contraction - toroidal component
# real component
yor[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 1], self.weights[1].to(yttr.dtype)) \
- torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 0], self.weights[0].to(yttr.dtype))
xs[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 1], self.weights[1].to(xs.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 0], self.weights[0].to(xs.dtype))
# imag component
yor[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', yttr[..., 0, :, :, 0], self.weights[1].to(yttr.dtype)) \
- torch.einsum('...km,mlk->...lm', yttr[..., 1, :, :, 1], self.weights[0].to(yttr.dtype))
xs[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :, 0], self.weights[1].to(xs.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :, 1], self.weights[0].to(xs.dtype))
# pad if required
yopr = F.pad(yor, [0, 0, 0, 0, 0, self.lpad], mode="constant")
yop = torch.view_as_complex(yopr)
x = torch.view_as_complex(xs)
# transpose: after this, l is split and c is local
if self.comm_size_polar > 1:
y = distributed_transpose_polar.apply(yop, (-2, 1))
else:
y = yop
chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)
return y
return x
class DistributedInverseRealVectorSHT(nn.Module):
......@@ -556,42 +459,22 @@ class DistributedInverseRealVectorSHT(nn.Module):
# determine the dimensions
self.mmax = mmax or self.nlon // 2 + 1
# spatial paddings
latdist = (self.nlat + self.comm_size_polar - 1) // self.comm_size_polar
self.nlatpad = latdist * self.comm_size_polar - self.nlat
londist = (self.nlon + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.nlonpad = londist * self.comm_size_azimuth - self.nlon
# frequency paddings
ldist = (self.lmax + self.comm_size_polar - 1) // self.comm_size_polar
self.lpad = ldist * self.comm_size_polar - self.lmax
mdist = (self.mmax + self.comm_size_azimuth - 1) // self.comm_size_azimuth
self.mpad = mdist * self.comm_size_azimuth - self.mmax
# compute splits
self.lat_shapes = compute_split_shapes(self.nlat, self.comm_size_polar)
self.lon_shapes = compute_split_shapes(self.nlon, self.comm_size_azimuth)
self.l_shapes = compute_split_shapes(self.lmax, self.comm_size_polar)
self.m_shapes = compute_split_shapes(self.mmax, self.comm_size_azimuth)
# compute legendre polynomials
dpct = _precompute_dlegpoly(self.mmax, self.lmax, t, norm=self.norm, inverse=True, csphase=self.csphase)
dpct = torch.from_numpy(dpct)
# split in m
dpct = F.pad(dpct, [0, 0, 0, 0, 0, self.mpad], mode="constant")
dpct = torch.split(dpct, (self.mmax+self.mpad) // self.comm_size_azimuth, dim=0)[self.comm_rank_azimuth]
dpct = split_tensor_along_dim(dpct, dim=1, num_chunks=self.comm_size_azimuth)[self.comm_rank_azimuth]
# register buffer
self.register_buffer('dpct', dpct, persistent=False)
# compute the local pad and size
# spatial
self.nlat_local = min(latdist, self.nlat - self.comm_rank_polar * latdist)
self.nlatpad_local = latdist - self.nlat_local
self.nlon_local = min(londist, self.nlon - self.comm_rank_azimuth * londist)
self.nlonpad_local = londist - self.nlon_local
# frequency
self.lmax_local = min(ldist, self.lmax - self.comm_rank_polar * ldist)
self.lpad_local = ldist - self.lmax_local
self.mmax_local = min(mdist, self.mmax - self.comm_rank_azimuth * mdist)
self.mpad_local = mdist - self.mmax_local
def extra_repr(self):
"""
Pretty print module
......@@ -600,36 +483,31 @@ class DistributedInverseRealVectorSHT(nn.Module):
def forward(self, x: torch.Tensor):
assert(x.shape[1] % self.comm_size_polar == 0)
assert(x.shape[1] % self.comm_size_azimuth == 0)
# store num channels
num_chans = x.shape[1]
# transpose: after that, channels are split, l is local:
if self.comm_size_polar > 1:
xt = distributed_transpose_polar.apply(x, (1, -2))
else:
xt = x
# remove padding in l:
xtt = xt[..., :self.lmax, :]
x = distributed_transpose_polar.apply(x, (1, -2), self.l_shapes)
# Evaluate associated Legendre functions on the output nodes
xttr = torch.view_as_real(xtt)
x = torch.view_as_real(x)
# contraction - spheroidal component
# real component
srl = torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 0], self.dpct[0].to(xttr.dtype)) \
- torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 1], self.dpct[1].to(xttr.dtype))
srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1].to(x.dtype))
# imag component
sim = torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 1], self.dpct[0].to(xttr.dtype)) \
+ torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 0], self.dpct[1].to(xttr.dtype))
sim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) \
+ torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1].to(x.dtype))
# contraction - toroidal component
# real component
trl = - torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 1], self.dpct[1].to(xttr.dtype)) \
- torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 0], self.dpct[0].to(xttr.dtype))
trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0].to(x.dtype))
# imag component
tim = torch.einsum('...lm,mlk->...km', xttr[..., 0, :, :, 0], self.dpct[1].to(xttr.dtype)) \
- torch.einsum('...lm,mlk->...km', xttr[..., 1, :, :, 1], self.dpct[0].to(xttr.dtype))
tim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0].to(x.dtype))
# reassemble
s = torch.stack((srl, sim), -1)
......@@ -639,33 +517,20 @@ class DistributedInverseRealVectorSHT(nn.Module):
# convert to complex
x = torch.view_as_complex(xs)
# transpose: after this, nlat is split and channels are local
xp = F.pad(x, [0, 0, 0, self.nlatpad])
if self.comm_size_polar > 1:
y = distributed_transpose_polar.apply(xp, (-2, 1))
else:
y = xp
chan_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-2, 1), chan_shapes)
# transpose: after this, channels are split and m is local
if self.comm_size_azimuth > 1:
yt = distributed_transpose_azimuth.apply(y, (1, -1))
else:
yt = y
# truncate
ytt = yt[..., :self.mmax]
x = distributed_transpose_azimuth.apply(x, (1, -1), self.m_shapes)
# apply the inverse (real) FFT
x = torch.fft.irfft(x, n=self.nlon, dim=-1, norm="forward")
# pad before we transpose back
xp = F.pad(x, [0, self.nlonpad])
# transpose: after this, nlon is split and channels are local
if self.comm_size_azimuth > 1:
out = distributed_transpose_azimuth.apply(xp, (-1, 1))
else:
out = xp
chan_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-1, 1), chan_shapes)
return out
\ No newline at end of file
return x
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import unittest
from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd
class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
@classmethod
def setUpClass(cls):
# set up distributed
cls.world_rank = int(os.getenv('WORLD_RANK', 0))
cls.grid_size_h = int(os.getenv('GRID_H', 1))
cls.grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', '29501'))
master_address = os.getenv('MASTER_ADDR', 'localhost')
cls.world_size = cls.grid_size_h * cls.grid_size_w
if torch.cuda.is_available():
if cls.world_rank == 0:
print("Running test on GPU")
local_rank = cls.world_rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{local_rank}")
torch.cuda.manual_seed(333)
proc_backend = 'nccl'
else:
if cls.world_rank == 0:
print("Running test on CPU")
cls.device = torch.device('cpu')
proc_backend = 'gloo'
torch.manual_seed(333)
dist.init_process_group(backend = proc_backend,
init_method = f"tcp://{master_address}:{port}",
rank = cls.world_rank,
world_size = cls.world_size)
cls.wrank = cls.world_rank % cls.grid_size_w
cls.hrank = cls.world_rank // cls.grid_size_w
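# hypothetical example: on a 2 x 2 grid (grid_size_w = grid_size_h = 2), world_rank = 3 maps to wrank = 1, hrank = 1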
# now set up the comm groups:
# set default
cls.w_group = None
cls.h_group = None
# do the init
wgroups = []
for w in range(0, cls.world_size, cls.grid_size_w):
start = w
end = w + cls.grid_size_w
wgroups.append(list(range(start, end)))
if cls.world_rank == 0:
print("w-groups:", wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if cls.world_rank in grp:
cls.w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
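# for the hypothetical 2 x 2 grid above this yields wgroups = [[0, 1], [2, 3]] and hgroups = [[0, 2], [1, 3]]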
if cls.world_rank == 0:
print("h-groups:", hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if cls.world_rank in grp:
cls.h_group = tmp_group
# set seed
torch.manual_seed(333)
if cls.world_rank == 0:
print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")
# initializing sht
thd.init(cls.h_group, cls.w_group)
def _split_helper(self, tensor):
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
tensor_local = tensor_list_local[self.wrank]
# split in H
tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h)
tensor_local = tensor_list_local[self.hrank]
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, transform_dist, vector):
# we need the shapes
l_shapes = transform_dist.l_shapes
m_shapes = transform_dist.m_shapes
# gather in W
if self.grid_size_w > 1:
if vector:
gather_shapes = [(B, C, 2, l_shapes[self.hrank], m) for m in m_shapes]
else:
gather_shapes = [(B, C, l_shapes[self.hrank], m) for m in m_shapes]
olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
olist[self.wrank] = tensor
dist.all_gather(olist, tensor, group=self.w_group)
tensor_gather = torch.cat(olist, dim=-1)
else:
tensor_gather = tensor
# gather in H
if self.grid_size_h > 1:
if vector:
gather_shapes = [(B, C, 2, l, transform_dist.mmax) for l in l_shapes]
else:
gather_shapes = [(B, C, l, transform_dist.mmax) for l in l_shapes]
olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
olist[self.hrank] = tensor_gather
dist.all_gather(olist, tensor_gather, group=self.h_group)
tensor_gather = torch.cat(olist, dim=-2)
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
# we need the shapes
lat_shapes = transform_dist.lat_shapes
lon_shapes = transform_dist.lon_shapes
# gather in W
if self.grid_size_w > 1:
if vector:
gather_shapes = [(B, C, 2, lat_shapes[self.hrank], w) for w in lon_shapes]
else:
gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
olist[self.wrank] = tensor
dist.all_gather(olist, tensor, group=self.w_group)
tensor_gather = torch.cat(olist, dim=-1)
else:
tensor_gather = tensor
# gather in H
if self.grid_size_h > 1:
if vector:
gather_shapes = [(B, C, 2, h, transform_dist.nlon) for h in lat_shapes]
else:
gather_shapes = [(B, C, h, transform_dist.nlon) for h in lat_shapes]
olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
olist[self.hrank] = tensor_gather
dist.all_gather(olist, tensor_gather, group=self.h_group)
tensor_gather = torch.cat(olist, dim=-2)
return tensor_gather
@parameterized.expand([
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[361, 720, 1, 10, "equiangular", False, 1e-6],
[361, 720, 1, 10, "legendre-gauss", False, 1e-6],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[361, 720, 1, 10, "equiangular", True, 1e-6],
[361, 720, 1, 10, "legendre-gauss", True, 1e-6],
])
def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon
# set up handles
if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors
if vector:
inp_full = torch.randn((B, C, 2, H, W), dtype=torch.float32, device=self.device)
else:
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local transform
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
# create grad for backward
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# BWD pass
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
out_local = forward_transform_dist(inp_local)
# BWD pass
ograd_local = self._split_helper(ograd_full)
out_local = forward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)
@parameterized.expand([
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[361, 720, 1, 10, "equiangular", False, 1e-6],
[361, 720, 1, 10, "legendre-gauss", False, 1e-6],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[361, 720, 1, 10, "equiangular", True, 1e-6],
[361, 720, 1, 10, "legendre-gauss", True, 1e-6],
])
def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon
if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors
if vector:
dummy_full = torch.randn((B, C, 2, H, W), dtype=torch.float32, device=self.device)
else:
dummy_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
inp_full = forward_transform_local(dummy_full)
#############################################################
# local transform
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = backward_transform_local(inp_full)
# create grad for backward
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# BWD pass
out_full.backward(ograd_full)
# repeat once due to known irfft bug
inp_full.grad = None
out_full = backward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
out_local = backward_transform_dist(inp_local)
# BWD pass
ograd_local = self._split_helper(ograd_full)
out_local = backward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)
if __name__ == '__main__':
unittest.main()
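# note: the tests read WORLD_RANK, GRID_H, GRID_W, MASTER_ADDR and MASTER_PORT from the
# environment (see setUpClass above); with the defaults a single process runs the 1 x 1 grid case.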
......@@ -28,12 +28,34 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import List
import torch
import torch.distributed as dist
from .utils import polar_group, azimuth_group, is_initialized
# helper routine to compute an uneven splitting in a balanced way:
def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
# treat trivial case first
if num_chunks == 1:
return [size]
# first, check if we can split using div-up to balance the load:
chunk_size = (size + num_chunks - 1) // num_chunks
last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
if last_chunk_size == 0:
# in this case, the last shard would be empty, split with floor instead:
chunk_size = size // num_chunks
last_chunk_size = size - chunk_size * (num_chunks-1)
# generate sections list
sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]
return sections
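# illustrative examples (not part of the library API, just to document the behaviour):
#   compute_split_shapes(10, 4) -> [3, 3, 3, 1]   (ceil-div split)
#   compute_split_shapes(5, 4)  -> [1, 1, 1, 2]   (floor fallback so that no shard is empty)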
# general helpers
def get_memory_format(tensor):
if tensor.is_contiguous(memory_format=torch.channels_last):
......@@ -41,75 +63,92 @@ def get_memory_format(tensor):
else:
return torch.contiguous_format
def split_tensor_along_dim(tensor, dim, num_chunks):
assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
assert (tensor.shape[dim] % num_chunks == 0), f"Error, cannot split dim {dim} evenly. Dim size is \
{tensor.shape[dim]} and requested number of splits is {num_chunks}"
chunk_size = tensor.shape[dim] // num_chunks
tensor_list = torch.split(tensor, chunk_size, dim=dim)
assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
{num_chunks} chunks. Empty slices are currently not supported."
# get split
sections = compute_split_shapes(tensor.shape[dim], num_chunks)
tensor_list = torch.split(tensor, sections, dim=dim)
return tensor_list
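# illustrative: splitting a tensor of shape (2, 7) into 3 chunks along dim=1 returns views of
# shapes (2, 3), (2, 3) and (2, 1), consistent with compute_split_shapes(7, 3) = [3, 3, 1].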
def _transpose(tensor, dim0, dim1, group=None, async_op=False):
def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
# get input format
input_format = get_memory_format(tensor)
# get comm params
comm_size = dist.get_world_size(group=group)
comm_rank = dist.get_rank(group=group)
# split and local transposition
split_size = tensor.shape[dim0] // comm_size
x_send = [y.contiguous(memory_format=input_format) for y in torch.split(tensor, split_size, dim=dim0)]
x_recv = [torch.empty_like(x_send[0]).contiguous(memory_format=input_format) for _ in range(comm_size)]
tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
x_send = [y.contiguous(memory_format=input_format) for y in tsplit]
x_send_shapes = [x.shape for x in x_send]
x_recv = []
x_shape = list(x_send_shapes[comm_rank])
for dim1_len in dim1_split_sizes:
x_shape[dim1] = dim1_len
x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device, memory_format=input_format))
# global transposition
req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)
# get dim0 split sizes
dim0_split_sizes = [x[dim0] for x in x_send_shapes]
return x_recv, req
return x_recv, dim0_split_sizes, req
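# note: callers stash the returned dim0_split_sizes so that the backward pass can undo the
# transpose with the correct (possibly uneven) split sizes.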
class distributed_transpose_azimuth(torch.autograd.Function):
@staticmethod
def forward(ctx, x, dim):
def forward(ctx, x, dims, dim1_split_sizes):
input_format = get_memory_format(x)
# WAR for a potential contig check torch bug for channels last contig tensors
x = x.contiguous()
xlist, _ = _transpose(x, dim[0], dim[1], group=azimuth_group())
x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
ctx.dim = dim
xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
x = torch.cat(xlist, dim=dims[1]).contiguous(memory_format=input_format)
ctx.dims = dims
ctx.dim0_split_sizes = dim0_split_sizes
return x
@staticmethod
def backward(ctx, go):
input_format = get_memory_format(go)
dim = ctx.dim
dims = ctx.dims
dim0_split_sizes = ctx.dim0_split_sizes
# WAR for a potential contig check torch bug for channels last contig tensors
go = go.contiguous()
gilist, _ = _transpose(go, dim[1], dim[0], group=azimuth_group())
gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
return gi, None
gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
gi = torch.cat(gilist, dim=dims[0]).contiguous(memory_format=input_format)
return gi, None, None
class distributed_transpose_polar(torch.autograd.Function):
@staticmethod
def forward(ctx, x, dim):
def forward(ctx, x, dim, dim1_split_sizes):
input_format = get_memory_format(x)
# WAR for a potential contig check torch bug for channels last contig tensors
x = x.contiguous()
xlist, _ = _transpose(x, dim[0], dim[1], group=polar_group())
xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
ctx.dim = dim
ctx.dim0_split_sizes = dim0_split_sizes
return x
@staticmethod
def backward(ctx, go):
input_format = get_memory_format(go)
dim = ctx.dim
dim0_split_sizes = ctx.dim0_split_sizes
# WAR for a potential contig check torch bug for channels last contig tensors
go = go.contiguous()
gilist, _ = _transpose(go, dim[1], dim[0], group=polar_group())
gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
return gi, None
return gi, None, None