Unverified commit eab72f04, authored by Thorsten Kurth and committed by GitHub

Tkurth/flexible sharding (#22)

* working distributed fwd SHT with flexible sharding and without padding

* fixing distributed sht with new logic

* fixed distributed SHT and ISHT with flexible padding

* working unittest for distributed fwd SHT

* distributed tests converted to unit tests

* bumping version number up to 0.6.4

* fixing a small splitting bug in the distributed transforms

* updated changelog, removed tests folder
parent 31a33579
@@ -2,6 +2,11 @@
## Versioning

+### v0.6.4
+* reworked the distributed transforms to allow for unevenly split tensors, which removes the need to pad the transformed tensors
+* distributed SHT tests now use unittest and have been extended to the vector SHT; they are defined in `torch_harmonics/distributed/distributed_tests.py`
+* base PyTorch container version bumped up to 23.11 in the Dockerfile
### v0.6.3
* Adding gradient check in unit tests
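The uneven splitting mentioned in the v0.6.4 entry above is handled by the new `compute_split_shapes` helper added to `torch_harmonics/distributed/primitives.py` further down in this diff. A minimal sketch, assuming the torch_harmonics version from this PR is installed, of the balanced section sizes it produces:

# illustrative only: section sizes produced by the new helper (defined in primitives.py below)
from torch_harmonics.distributed import compute_split_shapes

print(compute_split_shapes(721, 4))  # -> [181, 181, 181, 178]  (ceil-div, remainder goes to the last rank)
print(compute_split_shapes(9, 4))    # -> [2, 2, 2, 3]          (falls back to floor division so no shard is empty)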
@@ -30,9 +30,10 @@
# build after cloning in directory torch_harmonics via
# docker build . -t torch_harmonics
-FROM nvcr.io/nvidia/pytorch:22.08-py3
+FROM nvcr.io/nvidia/pytorch:23.11-py3
COPY . /workspace/torch_harmonics
-RUN pip install --use-feature=in-tree-build /workspace/torch_harmonics
+RUN pip install parameterized
+RUN pip install /workspace/torch_harmonics
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
import os
sys.path.append("..")
sys.path.append(".")
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x
# set up distributed
world_rank = int(os.getenv('WORLD_RANK', 0))
grid_size_h = int(os.getenv('GRID_H', 1))
grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
world_size = grid_size_h * grid_size_w
dist.init_process_group(backend = 'nccl',
init_method = f"tcp://{master_address}:{port}",
rank = world_rank,
world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")
# compute local ranks in h and w:
# rank = wrank + grid_size_w * hrank
wrank = world_rank % grid_size_w
hrank = world_rank // grid_size_w
w_group = None
h_group = None
# now set up the comm grid:
wgroups = []
for h in range(grid_size_h):
    # each w-group holds the grid_size_w consecutive ranks of row h
    start = h * grid_size_w
    end = start + grid_size_w
    wgroups.append(list(range(start, end)))
print(wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if wrank in grp:
w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
print(hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if hrank in grp:
h_group = tmp_group
# set device
torch.cuda.set_device(device.index)
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
if world_rank == 0:
print(f"Running distributed test on grid H x W = {grid_size_h} x {grid_size_w}")
# initializing sht
thd.init(h_group, w_group)
# common parameters
B, C, H, W = 1, 8, 721, 1440
Hloc = (H + grid_size_h - 1) // grid_size_h
Wloc = (W + grid_size_w - 1) // grid_size_w
Hpad = grid_size_h * Hloc - H
Wpad = grid_size_w * Wloc - W
# do serial tests first:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W).to(device)
backward_transform_local = harmonics.InverseRealSHT(nlat=H, nlon=W).to(device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W).to(device)
Lpad = backward_transform_dist.lpad
Mpad = backward_transform_dist.mpad
Lloc = (Lpad + backward_transform_dist.lmax) // grid_size_h
Mloc = (Mpad + backward_transform_dist.mmax) // grid_size_w
# create tensors
dummy_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device)
inp_full = forward_transform_local(dummy_full)
# pad
with torch.no_grad():
inp_pad = F.pad(inp_full, (0, Mpad, 0, Lpad))
# split in W
inp_local = torch.split(inp_pad, split_size_or_sections=Mloc, dim=-1)[wrank]
# split in H
inp_local = torch.split(inp_local, split_size_or_sections=Lloc, dim=-2)[hrank]
# do FWD transform
out_full = backward_transform_local(inp_full)
out_local = backward_transform_dist(inp_local)
# gather the local data
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(out_local) for _ in range(grid_size_w)]
olist[wrank] = out_local
dist.all_gather(olist, out_local, group=w_group)
out_full_gather = torch.cat(olist, dim=-1)
out_full_gather = out_full_gather[..., :W]
else:
out_full_gather = out_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(out_full_gather) for _ in range(grid_size_h)]
olist[hrank] = out_full_gather
dist.all_gather(olist, out_full_gather, group=h_group)
out_full_gather = torch.cat(olist, dim=-2)
out_full_gather = out_full_gather[..., :H, :]
if world_rank == 0:
print(f"Local Out: sum={out_full.abs().sum().item()}, max={out_full.abs().max().item()}, min={out_full.abs().min().item()}")
print(f"Dist Out: sum={out_full_gather.abs().sum().item()}, max={out_full_gather.abs().max().item()}, min={out_full_gather.abs().min().item()}")
diff = (out_full-out_full_gather).abs()
print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(out_full.abs().sum() + out_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
print("")
# create split input grad
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# pad
ograd_pad = F.pad(ograd_full, [0, Wpad, 0, Hpad])
# split in W
ograd_local = torch.split(ograd_pad, split_size_or_sections=Wloc, dim=-1)[wrank]
# split in H
ograd_local = torch.split(ograd_local, split_size_or_sections=Hloc, dim=-2)[hrank]
# backward pass:
# local
inp_full.requires_grad = True
out_full = backward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
# distributed
inp_local.requires_grad = True
out_local = backward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
# gather
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(igrad_local) for _ in range(grid_size_w)]
olist[wrank] = igrad_local
dist.all_gather(olist, igrad_local, group=w_group)
igrad_full_gather = torch.cat(olist, dim=-1)
igrad_full_gather = igrad_full_gather[..., :backward_transform_dist.mmax]
else:
igrad_full_gather = igrad_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(igrad_full_gather) for _ in range(grid_size_h)]
olist[hrank] = igrad_full_gather
dist.all_gather(olist, igrad_full_gather, group=h_group)
igrad_full_gather = torch.cat(olist, dim=-2)
igrad_full_gather = igrad_full_gather[..., :backward_transform_dist.lmax, :]
if world_rank == 0:
print(f"Local Grad: sum={igrad_full.abs().sum().item()}, max={igrad_full.abs().max().item()}, min={igrad_full.abs().min().item()}")
print(f"Dist Grad: sum={igrad_full_gather.abs().sum().item()}, max={igrad_full_gather.abs().max().item()}, min={igrad_full_gather.abs().min().item()}")
diff = (igrad_full-igrad_full_gather).abs()
print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(igrad_full.abs().sum() + igrad_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ignore this (just for development without installation)
import sys
import os
sys.path.append("..")
sys.path.append(".")
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x: x
# set up distributed
world_rank = int(os.getenv('WORLD_RANK', 0))
grid_size_h = int(os.getenv('GRID_H', 1))
grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', 0))
master_address = os.getenv('MASTER_ADDR', 'localhost')
world_size = grid_size_h * grid_size_w
dist.init_process_group(backend = 'nccl',
init_method = f"tcp://{master_address}:{port}",
rank = world_rank,
world_size = world_size)
local_rank = world_rank % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")
# compute local ranks in h and w:
# rank = wrank + grid_size_w * hrank
wrank = world_rank % grid_size_w
hrank = world_rank // grid_size_w
w_group = None
h_group = None
# now set up the comm grid:
wgroups = []
for h in range(grid_size_h):
    # each w-group holds the grid_size_w consecutive ranks of row h
    start = h * grid_size_w
    end = start + grid_size_w
    wgroups.append(list(range(start, end)))
print(wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if wrank in grp:
w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
print(hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if hrank in grp:
h_group = tmp_group
# set device
torch.cuda.set_device(device.index)
# set seed
torch.manual_seed(333)
torch.cuda.manual_seed(333)
if world_rank == 0:
print(f"Running distributed test on grid H x W = {grid_size_h} x {grid_size_w}")
# initializing sht
thd.init(h_group, w_group)
# common parameters
B, C, H, W = 1, 8, 721, 1440
Hloc = (H + grid_size_h - 1) // grid_size_h
Wloc = (W + grid_size_w - 1) // grid_size_w
Hpad = grid_size_h * Hloc - H
Wpad = grid_size_w * Wloc - W
# do serial tests first:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W).to(device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W).to(device)
Lloc = (forward_transform_dist.lpad + forward_transform_dist.lmax) // grid_size_h
Mloc = (forward_transform_dist.mpad + forward_transform_dist.mmax) // grid_size_w
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=device)
# pad
with torch.no_grad():
inp_pad = F.pad(inp_full, (0, Wpad, 0, Hpad))
# split in W
inp_local = torch.split(inp_pad, split_size_or_sections=Wloc, dim=-1)[wrank]
# split in H
inp_local = torch.split(inp_local, split_size_or_sections=Hloc, dim=-2)[hrank]
# do FWD transform
out_full = forward_transform_local(inp_full)
out_local = forward_transform_dist(inp_local)
# gather the local data
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(out_local) for _ in range(grid_size_w)]
olist[wrank] = out_local
dist.all_gather(olist, out_local, group=w_group)
out_full_gather = torch.cat(olist, dim=-1)
out_full_gather = out_full_gather[..., :forward_transform_dist.mmax]
else:
out_full_gather = out_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(out_full_gather) for _ in range(grid_size_h)]
olist[hrank] = out_full_gather
dist.all_gather(olist, out_full_gather, group=h_group)
out_full_gather = torch.cat(olist, dim=-2)
out_full_gather = out_full_gather[..., :forward_transform_dist.lmax, :]
if world_rank == 0:
print(f"Local Out: sum={out_full.abs().sum().item()}, max={out_full.abs().max().item()}, min={out_full.abs().min().item()}")
print(f"Dist Out: sum={out_full_gather.abs().sum().item()}, max={out_full_gather.abs().max().item()}, min={out_full_gather.abs().min().item()}")
diff = (out_full-out_full_gather).abs()
print(f"Out Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(out_full.abs().sum() + out_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
print("")
# create split input grad
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# pad
ograd_pad = F.pad(ograd_full, [0, forward_transform_dist.mpad, 0, forward_transform_dist.lpad])
# split in M
ograd_local = torch.split(ograd_pad, split_size_or_sections=Mloc, dim=-1)[wrank]
# split in H
ograd_local = torch.split(ograd_local, split_size_or_sections=Lloc, dim=-2)[hrank]
# backward pass:
# local
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
# distributed
inp_local.requires_grad = True
out_local = forward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
# gather
# gather in W
if grid_size_w > 1:
olist = [torch.empty_like(igrad_local) for _ in range(grid_size_w)]
olist[wrank] = igrad_local
dist.all_gather(olist, igrad_local, group=w_group)
igrad_full_gather = torch.cat(olist, dim=-1)
igrad_full_gather = igrad_full_gather[..., :W]
else:
igrad_full_gather = igrad_local
# gather in h
if grid_size_h > 1:
olist = [torch.empty_like(igrad_full_gather) for _ in range(grid_size_h)]
olist[hrank] = igrad_full_gather
dist.all_gather(olist, igrad_full_gather, group=h_group)
igrad_full_gather = torch.cat(olist, dim=-2)
igrad_full_gather = igrad_full_gather[..., :H, :]
if world_rank == 0:
print(f"Local Grad: sum={igrad_full.abs().sum().item()}, max={igrad_full.abs().max().item()}, min={igrad_full.abs().min().item()}")
print(f"Dist Grad: sum={igrad_full_gather.abs().sum().item()}, max={igrad_full_gather.abs().max().item()}, min={igrad_full_gather.abs().min().item()}")
diff = (igrad_full-igrad_full_gather).abs()
print(f"Grad Difference: abs={diff.sum().item()}, rel={diff.sum().item() / (0.5*(igrad_full.abs().sum() + igrad_full_gather.abs().sum()))}, max={diff.abs().max().item()}")
@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

-__version__ = '0.6.3'
+__version__ = '0.6.4'

from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature
@@ -32,7 +32,7 @@
# we need this in order to enable distributed
from .utils import init, is_initialized, polar_group, azimuth_group
from .utils import polar_group_size, azimuth_group_size, polar_group_rank, azimuth_group_rank
-from .primitives import distributed_transpose_azimuth, distributed_transpose_polar
+from .primitives import compute_split_shapes, split_tensor_along_dim, distributed_transpose_azimuth, distributed_transpose_polar

# import the sht stuff
from .distributed_sht import DistributedRealSHT, DistributedInverseRealSHT
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import unittest
from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd
class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
@classmethod
def setUpClass(cls):
# set up distributed
cls.world_rank = int(os.getenv('WORLD_RANK', 0))
cls.grid_size_h = int(os.getenv('GRID_H', 1))
cls.grid_size_w = int(os.getenv('GRID_W', 1))
port = int(os.getenv('MASTER_PORT', '29501'))
master_address = os.getenv('MASTER_ADDR', 'localhost')
cls.world_size = cls.grid_size_h * cls.grid_size_w
if torch.cuda.is_available():
if cls.world_rank == 0:
print("Running test on GPU")
local_rank = cls.world_rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{local_rank}")
torch.cuda.manual_seed(333)
proc_backend = 'nccl'
else:
if cls.world_rank == 0:
print("Running test on CPU")
cls.device = torch.device('cpu')
proc_backend = 'gloo'
torch.manual_seed(333)
dist.init_process_group(backend = proc_backend,
init_method = f"tcp://{master_address}:{port}",
rank = cls.world_rank,
world_size = cls.world_size)
cls.wrank = cls.world_rank % cls.grid_size_w
cls.hrank = cls.world_rank // cls.grid_size_w
# now set up the comm groups:
#set default
cls.w_group = None
cls.h_group = None
# do the init
wgroups = []
for w in range(0, cls.world_size, cls.grid_size_w):
start = w
end = w + cls.grid_size_w
wgroups.append(list(range(start, end)))
if cls.world_rank == 0:
print("w-groups:", wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if cls.world_rank in grp:
cls.w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
if cls.world_rank == 0:
print("h-groups:", hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if cls.world_rank in grp:
cls.h_group = tmp_group
# set seed
torch.manual_seed(333)
if cls.world_rank == 0:
print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")
# initializing sht
thd.init(cls.h_group, cls.w_group)
def _split_helper(self, tensor):
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
tensor_local = tensor_list_local[self.wrank]
# split in H
tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h)
tensor_local = tensor_list_local[self.hrank]
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, transform_dist, vector):
# we need the shapes
l_shapes = transform_dist.l_shapes
m_shapes = transform_dist.m_shapes
# gather in W
if self.grid_size_w > 1:
if vector:
gather_shapes = [(B, C, 2, l_shapes[self.hrank], m) for m in m_shapes]
else:
gather_shapes = [(B, C, l_shapes[self.hrank], m) for m in m_shapes]
olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
olist[self.wrank] = tensor
dist.all_gather(olist, tensor, group=self.w_group)
tensor_gather = torch.cat(olist, dim=-1)
else:
tensor_gather = tensor
# gather in H
if self.grid_size_h > 1:
if vector:
gather_shapes = [(B, C, 2, l, transform_dist.mmax) for l in l_shapes]
else:
gather_shapes = [(B, C, l, transform_dist.mmax) for l in l_shapes]
olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
olist[self.hrank] = tensor_gather
dist.all_gather(olist, tensor_gather, group=self.h_group)
tensor_gather = torch.cat(olist, dim=-2)
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, transform_dist, vector):
# we need the shapes
lat_shapes = transform_dist.lat_shapes
lon_shapes = transform_dist.lon_shapes
# gather in W
if self.grid_size_w > 1:
if vector:
gather_shapes = [(B, C, 2, lat_shapes[self.hrank], w) for w in lon_shapes]
else:
gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
olist[self.wrank] = tensor
dist.all_gather(olist, tensor, group=self.w_group)
tensor_gather = torch.cat(olist, dim=-1)
else:
tensor_gather = tensor
# gather in H
if self.grid_size_h > 1:
if vector:
gather_shapes = [(B, C, 2, h, transform_dist.nlon) for h in lat_shapes]
else:
gather_shapes = [(B, C, h, transform_dist.nlon) for h in lat_shapes]
olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
olist[self.hrank] = tensor_gather
dist.all_gather(olist, tensor_gather, group=self.h_group)
tensor_gather = torch.cat(olist, dim=-2)
return tensor_gather
@parameterized.expand([
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[361, 720, 1, 10, "equiangular", False, 1e-6],
[361, 720, 1, 10, "legendre-gauss", False, 1e-6],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[361, 720, 1, 10, "equiangular", True, 1e-6],
[361, 720, 1, 10, "legendre-gauss", True, 1e-6],
])
def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon
# set up handles
if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
forward_transform_dist = thd.DistributedRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors
if vector:
inp_full = torch.randn((B, C, 2, H, W), dtype=torch.float32, device=self.device)
else:
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local transform
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = forward_transform_local(inp_full)
# create grad for backward
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# BWD pass
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
out_local = forward_transform_dist(inp_local)
# BWD pass
ograd_local = self._split_helper(ograd_full)
out_local = forward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)
@parameterized.expand([
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[256, 512, 32, 8, "equiangular", False, 1e-9],
[256, 512, 32, 8, "legendre-gauss", False, 1e-9],
[361, 720, 1, 10, "equiangular", False, 1e-6],
[361, 720, 1, 10, "legendre-gauss", False, 1e-6],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[256, 512, 32, 8, "equiangular", True, 1e-9],
[256, 512, 32, 8, "legendre-gauss", True, 1e-9],
[361, 720, 1, 10, "equiangular", True, 1e-6],
[361, 720, 1, 10, "legendre-gauss", True, 1e-6],
])
def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
B, C, H, W = batch_size, num_chan, nlat, nlon
if vector:
forward_transform_local = harmonics.RealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealVectorSHT(nlat=H, nlon=W, grid=grid).to(self.device)
else:
forward_transform_local = harmonics.RealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_local = harmonics.InverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
backward_transform_dist = thd.DistributedInverseRealSHT(nlat=H, nlon=W, grid=grid).to(self.device)
# create tensors
if vector:
dummy_full = torch.randn((B, C, 2, H, W), dtype=torch.float32, device=self.device)
else:
dummy_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
inp_full = forward_transform_local(dummy_full)
#############################################################
# local transform
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = backward_transform_local(inp_full)
# create grad for backward
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# BWD pass
out_full.backward(ograd_full)
# repeat once due to known irfft bug
inp_full.grad = None
out_full = backward_transform_local(inp_full)
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed transform
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
out_local = backward_transform_dist(inp_local)
# BWD pass
ograd_local = self._split_helper(ograd_full)
out_local = backward_transform_dist(inp_local)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(out_full-out_gather_full, p='fro', dim=(-1,-2)) / torch.norm(out_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)
err = torch.mean(torch.norm(igrad_full-igrad_gather_full, p='fro', dim=(-1,-2)) / torch.norm(igrad_full, p='fro', dim=(-1,-2)) )
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)
if __name__ == '__main__':
unittest.main()
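The test class above reads its rank and grid layout from environment variables rather than from a launcher. Below is a hypothetical single-node launcher sketch, not part of this PR, assuming the file is importable as `torch_harmonics.distributed.distributed_tests` (the path named in the changelog) and that port 29501 is free:

# hypothetical helper, not part of this PR: spawn one process per rank and set the
# environment variables (WORLD_RANK, GRID_H, GRID_W, MASTER_ADDR, MASTER_PORT) that
# TestDistributedSphericalHarmonicTransform.setUpClass reads before dist.init_process_group
import os
import subprocess
import sys

def launch(grid_h: int = 1, grid_w: int = 2) -> None:
    world_size = grid_h * grid_w
    procs = []
    for rank in range(world_size):
        env = os.environ.copy()
        env.update({
            "WORLD_RANK": str(rank),
            "GRID_H": str(grid_h),
            "GRID_W": str(grid_w),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "29501",
        })
        procs.append(subprocess.Popen(
            [sys.executable, "-m", "torch_harmonics.distributed.distributed_tests"], env=env))
    # wait for all ranks; a non-zero return code means at least one test failed on that rank
    for p in procs:
        p.wait()

if __name__ == "__main__":
    launch()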
@@ -28,12 +28,34 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

+from typing import List
import torch
import torch.distributed as dist

from .utils import polar_group, azimuth_group, is_initialized

+# helper routine to compute uneven splitting in balanced way:
+def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
+
+    # treat trivial case first
+    if num_chunks == 1:
+        return [size]
+
+    # first, check if we can split using div-up to balance the load:
+    chunk_size = (size + num_chunks - 1) // num_chunks
+    last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
+    if last_chunk_size == 0:
+        # in this case, the last shard would be empty, split with floor instead:
+        chunk_size = size // num_chunks
+        last_chunk_size = size - chunk_size * (num_chunks - 1)
+
+    # generate sections list
+    sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]
+
+    return sections
+
# general helpers
def get_memory_format(tensor):
    if tensor.is_contiguous(memory_format=torch.channels_last):
@@ -41,75 +63,92 @@ def get_memory_format(tensor):
    else:
        return torch.contiguous_format

def split_tensor_along_dim(tensor, dim, num_chunks):
    assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
-    assert (tensor.shape[dim] % num_chunks == 0), f"Error, cannot split dim {dim} evenly. Dim size is \
-                                                    {tensor.shape[dim]} and requested numnber of splits is {num_chunks}"
-    chunk_size = tensor.shape[dim] // num_chunks
-    tensor_list = torch.split(tensor, chunk_size, dim=dim)
+    assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
+                                                {num_chunks} chunks. Empty slices are currently not supported."
+
+    # get split
+    sections = compute_split_shapes(tensor.shape[dim], num_chunks)
+    tensor_list = torch.split(tensor, sections, dim=dim)

    return tensor_list

-def _transpose(tensor, dim0, dim1, group=None, async_op=False):
+def _transpose(tensor, dim0, dim1, dim1_split_sizes, group=None, async_op=False):
    # get input format
    input_format = get_memory_format(tensor)

    # get comm params
    comm_size = dist.get_world_size(group=group)
+    comm_rank = dist.get_rank(group=group)

    # split and local transposition
-    split_size = tensor.shape[dim0] // comm_size
-    x_send = [y.contiguous(memory_format=input_format) for y in torch.split(tensor, split_size, dim=dim0)]
-    x_recv = [torch.empty_like(x_send[0]).contiguous(memory_format=input_format) for _ in range(comm_size)]
+    tsplit = split_tensor_along_dim(tensor, num_chunks=comm_size, dim=dim0)
+    x_send = [y.contiguous(memory_format=input_format) for y in tsplit]
+    x_send_shapes = [x.shape for x in x_send]
+    x_recv = []
+    x_shape = list(x_send_shapes[comm_rank])
+    for dim1_len in dim1_split_sizes:
+        x_shape[dim1] = dim1_len
+        x_recv.append(torch.empty(x_shape, dtype=tensor.dtype, device=tensor.device, memory_format=input_format))

    # global transposition
    req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op)

-    return x_recv, req
+    # get dim0 split sizes
+    dim0_split_sizes = [x[dim0] for x in x_send_shapes]
+
+    return x_recv, dim0_split_sizes, req

class distributed_transpose_azimuth(torch.autograd.Function):

    @staticmethod
-    def forward(ctx, x, dim):
+    def forward(ctx, x, dims, dim1_split_sizes):
        input_format = get_memory_format(x)
        # WAR for a potential contig check torch bug for channels last contig tensors
        x = x.contiguous()
-        xlist, _ = _transpose(x, dim[0], dim[1], group=azimuth_group())
-        x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
-        ctx.dim = dim
+        xlist, dim0_split_sizes, _ = _transpose(x, dims[0], dims[1], dim1_split_sizes, group=azimuth_group())
+        x = torch.cat(xlist, dim=dims[1]).contiguous(memory_format=input_format)
+        ctx.dims = dims
+        ctx.dim0_split_sizes = dim0_split_sizes
        return x

    @staticmethod
    def backward(ctx, go):
        input_format = get_memory_format(go)
-        dim = ctx.dim
+        dims = ctx.dims
+        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors
        go = go.contiguous()
-        gilist, _ = _transpose(go, dim[1], dim[0], group=azimuth_group())
-        gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
-        return gi, None
+        gilist, _, _ = _transpose(go, dims[1], dims[0], dim0_split_sizes, group=azimuth_group())
+        gi = torch.cat(gilist, dim=dims[0]).contiguous(memory_format=input_format)
+        return gi, None, None

class distributed_transpose_polar(torch.autograd.Function):

    @staticmethod
-    def forward(ctx, x, dim):
+    def forward(ctx, x, dim, dim1_split_sizes):
        input_format = get_memory_format(x)
        # WAR for a potential contig check torch bug for channels last contig tensors
        x = x.contiguous()
-        xlist, _ = _transpose(x, dim[0], dim[1], group=polar_group())
+        xlist, dim0_split_sizes, _ = _transpose(x, dim[0], dim[1], dim1_split_sizes, group=polar_group())
        x = torch.cat(xlist, dim=dim[1]).contiguous(memory_format=input_format)
        ctx.dim = dim
+        ctx.dim0_split_sizes = dim0_split_sizes
        return x

    @staticmethod
    def backward(ctx, go):
        input_format = get_memory_format(go)
        dim = ctx.dim
+        dim0_split_sizes = ctx.dim0_split_sizes
        # WAR for a potential contig check torch bug for channels last contig tensors
        go = go.contiguous()
-        gilist, _ = _transpose(go, dim[1], dim[0], group=polar_group())
+        gilist, _, _ = _transpose(go, dim[1], dim[0], dim0_split_sizes, group=polar_group())
        gi = torch.cat(gilist, dim=dim[0]).contiguous(memory_format=input_format)
-        return gi, None
+        return gi, None, None
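A short usage sketch of the two helpers that `torch_harmonics/distributed/__init__.py` now re-exports (see the import change above), showing why the uneven splits make padding unnecessary: splitting with the computed section sizes and concatenating recovers the original tensor exactly. This is illustrative only and runs in a single process:

import torch
from torch_harmonics.distributed import compute_split_shapes, split_tensor_along_dim

x = torch.randn(1, 8, 721, 1440)                             # the H x W = 721 x 1440 case from the removed scripts
sections = compute_split_shapes(x.shape[-2], num_chunks=4)   # -> [181, 181, 181, 178]
chunks = split_tensor_along_dim(x, dim=-2, num_chunks=4)

assert [c.shape[-2] for c in chunks] == sections             # shard sizes match the computed sections
assert torch.equal(torch.cat(chunks, dim=-2), x)             # concatenation restores the tensor, no padding to strip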