Commit e5a9c4af authored by Thorsten Kurth, committed by Boris Bonev

adding distributed resampling and test routines

parent 3350099a
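For context, a minimal sketch (not part of this commit) of the resampling API that the new distributed module mirrors; the constructor arguments are the same res_args that the test below passes to both ResampleS2 and DistributedResampleS2, while the tensor shapes in the comments are illustrative assumptions:

# Minimal sketch (not part of this commit): the local resampling API that
# DistributedResampleS2 mirrors; shapes below are illustrative assumptions.
import torch
import torch_harmonics as harmonics

resample = harmonics.ResampleS2(
    nlat_in=64, nlon_in=128, nlat_out=128, nlon_out=256,
    grid_in="equiangular", grid_out="equiangular",
)
x = torch.randn(2, 4, 64, 128)  # (batch, channels, nlat_in, nlon_in)
y = resample(x)                 # -> (2, 4, 128, 256) on the output grid
# thd.DistributedResampleS2 takes the same arguments and becomes a drop-in
# replacement once thd.init(h_group, w_group) has set up the process grid,
# as exercised by the test below.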
...@@ -205,7 +205,7 @@
 ],
 "metadata": {
 "kernelspec": {
- "display_name": "Python 3",
+ "display_name": "dace",
 "language": "python",
 "name": "python3"
 },
...@@ -219,7 +219,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.8.18"
 }
 },
 "nbformat": 4,
...
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import unittest
from parameterized import parameterized
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch_harmonics as harmonics
import torch_harmonics.distributed as thd
class TestDistributedResampling(unittest.TestCase):
@classmethod
def setUpClass(cls):
# set up distributed
cls.world_rank = int(os.getenv("WORLD_RANK", 0))
cls.grid_size_h = int(os.getenv("GRID_H", 1))
cls.grid_size_w = int(os.getenv("GRID_W", 1))
port = int(os.getenv("MASTER_PORT", "29501"))
master_address = os.getenv("MASTER_ADDR", "localhost")
cls.world_size = cls.grid_size_h * cls.grid_size_w
if torch.cuda.is_available():
if cls.world_rank == 0:
print("Running test on GPU")
local_rank = cls.world_rank % torch.cuda.device_count()
cls.device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(local_rank)
torch.cuda.manual_seed(333)
proc_backend = "nccl"
else:
if cls.world_rank == 0:
print("Running test on CPU")
cls.device = torch.device("cpu")
proc_backend = "gloo"
torch.manual_seed(333)
dist.init_process_group(backend=proc_backend, init_method=f"tcp://{master_address}:{port}", rank=cls.world_rank, world_size=cls.world_size)
cls.wrank = cls.world_rank % cls.grid_size_w
cls.hrank = cls.world_rank // cls.grid_size_w
# now set up the comm groups:
# set default
cls.w_group = None
cls.h_group = None
# do the init
wgroups = []
for w in range(0, cls.world_size, cls.grid_size_w):
start = w
end = w + cls.grid_size_w
wgroups.append(list(range(start, end)))
if cls.world_rank == 0:
print("w-groups:", wgroups)
for grp in wgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if cls.world_rank in grp:
cls.w_group = tmp_group
# transpose:
hgroups = [sorted(list(i)) for i in zip(*wgroups)]
if cls.world_rank == 0:
print("h-groups:", hgroups)
for grp in hgroups:
if len(grp) == 1:
continue
tmp_group = dist.new_group(ranks=grp)
if cls.world_rank in grp:
cls.h_group = tmp_group
if cls.world_rank == 0:
print(f"Running distributed tests on grid H x W = {cls.grid_size_h} x {cls.grid_size_w}")
# initializing sht
thd.init(cls.h_group, cls.w_group)
@classmethod
def tearDownClass(cls):
thd.finalize()
dist.destroy_process_group(None)
def _split_helper(self, tensor):
with torch.no_grad():
# split in W
tensor_list_local = thd.split_tensor_along_dim(tensor, dim=-1, num_chunks=self.grid_size_w)
tensor_local = tensor_list_local[self.wrank]
# split in H
tensor_list_local = thd.split_tensor_along_dim(tensor_local, dim=-2, num_chunks=self.grid_size_h)
tensor_local = tensor_list_local[self.hrank]
return tensor_local
def _gather_helper_fwd(self, tensor, B, C, resampling_dist):
# we need the shapes
lat_shapes = resampling_dist.lat_out_shapes
lon_shapes = resampling_dist.lon_out_shapes
# gather in W
tensor = tensor.contiguous()
if self.grid_size_w > 1:
gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
olist[self.wrank] = tensor
dist.all_gather(olist, tensor, group=self.w_group)
tensor_gather = torch.cat(olist, dim=-1)
else:
tensor_gather = tensor
# gather in H
tensor_gather = tensor_gather.contiguous()
if self.grid_size_h > 1:
gather_shapes = [(B, C, h, resampling_dist.nlon_out) for h in lat_shapes]
olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
olist[self.hrank] = tensor_gather
dist.all_gather(olist, tensor_gather, group=self.h_group)
tensor_gather = torch.cat(olist, dim=-2)
return tensor_gather
def _gather_helper_bwd(self, tensor, B, C, resampling_dist):
# we need the shapes
lat_shapes = resampling_dist.lat_in_shapes
lon_shapes = resampling_dist.lon_in_shapes
# gather in W
if self.grid_size_w > 1:
gather_shapes = [(B, C, lat_shapes[self.hrank], w) for w in lon_shapes]
olist = [torch.empty(shape, dtype=tensor.dtype, device=tensor.device) for shape in gather_shapes]
olist[self.wrank] = tensor
dist.all_gather(olist, tensor, group=self.w_group)
tensor_gather = torch.cat(olist, dim=-1)
else:
tensor_gather = tensor
# gather in H
if self.grid_size_h > 1:
gather_shapes = [(B, C, h, resampling_dist.nlon_in) for h in lat_shapes]
olist = [torch.empty(shape, dtype=tensor_gather.dtype, device=tensor_gather.device) for shape in gather_shapes]
olist[self.hrank] = tensor_gather
dist.all_gather(olist, tensor_gather, group=self.h_group)
tensor_gather = torch.cat(olist, dim=-2)
return tensor_gather
@parameterized.expand(
[
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", 1e-7],
]
)
def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, tol
):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
res_args = dict(
nlat_in=nlat_in,
nlon_in=nlon_in,
nlat_out=nlat_out,
nlon_out=nlon_out,
grid_in=grid_in,
grid_out=grid_out,
)
# set up handles
res_local = harmonics.ResampleS2(**res_args).to(self.device)
res_dist = thd.DistributedResampleS2(**res_args).to(self.device)
# create tensors
inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)
#############################################################
# local resampling
#############################################################
# FWD pass
inp_full.requires_grad = True
out_full = res_local(inp_full)
# create grad for backward
with torch.no_grad():
# create full grad
ograd_full = torch.randn_like(out_full)
# BWD pass
out_full.backward(ograd_full)
igrad_full = inp_full.grad.clone()
#############################################################
# distributed resampling
#############################################################
# FWD pass
inp_local = self._split_helper(inp_full)
inp_local.requires_grad = True
out_local = res_dist(inp_local)
# BWD pass
ograd_local = self._split_helper(ograd_full)
out_local.backward(ograd_local)
igrad_local = inp_local.grad.clone()
#############################################################
# evaluate FWD pass
#############################################################
with torch.no_grad():
out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
print(f"final relative error of output: {err.item()}")
self.assertTrue(err.item() <= tol)
#############################################################
# evaluate BWD pass
#############################################################
with torch.no_grad():
igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
if self.world_rank == 0:
print(f"final relative error of gradients: {err.item()}")
self.assertTrue(err.item() <= tol)
if __name__ == "__main__":
unittest.main()
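# Note (not part of the commit): the process grid is taken from the environment.
# Each rank reads WORLD_RANK, and GRID_H * GRID_W must equal the number of
# launched processes; MASTER_ADDR and MASTER_PORT default to localhost:29501.
# A hypothetical 4-rank launch would therefore export GRID_H=2, GRID_W=2 and a
# distinct WORLD_RANK in {0, 1, 2, 3} per process before invoking this file.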
...@@ -51,3 +51,6 @@ from .distributed_sht import DistributedRealVectorSHT, DistributedInverseRealVec
# import DISCO
from .distributed_convolution import DistributedDiscreteContinuousConvS2
from .distributed_convolution import DistributedDiscreteContinuousConvTransposeS2
# import resampling
from .distributed_resample import DistributedResampleS2
# coding=utf-8
# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from typing import List, Tuple, Union, Optional
import math
import numpy as np
import torch
import torch.nn as nn
from torch_harmonics.quadrature import _precompute_latitudes
from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
from torch_harmonics.distributed import compute_split_shapes
class DistributedResampleS2(nn.Module):
def __init__(
self,
nlat_in: int,
nlon_in: int,
nlat_out: int,
nlon_out: int,
grid_in: Optional[str] = "equiangular",
grid_out: Optional[str] = "equiangular",
mode: Optional[str] = "bilinear",
):
super().__init__()
# currently only bilinear is supported
if mode == "bilinear":
self.mode = mode
else:
raise NotImplementedError(f"unknown interpolation mode {mode}")
self.nlat_in, self.nlon_in = nlat_in, nlon_in
self.nlat_out, self.nlon_out = nlat_out, nlon_out
self.grid_in = grid_in
self.grid_out = grid_out
# get the comms grid:
self.comm_size_polar = polar_group_size()
self.comm_rank_polar = polar_group_rank()
self.comm_size_azimuth = azimuth_group_size()
self.comm_rank_azimuth = azimuth_group_rank()
# compute splits: is this correct even when expanding the poles?
self.lat_in_shapes = compute_split_shapes(self.nlat_in, self.comm_size_polar)
self.lon_in_shapes = compute_split_shapes(self.nlon_in, self.comm_size_azimuth)
self.lat_out_shapes = compute_split_shapes(self.nlat_out, self.comm_size_polar)
self.lon_out_shapes = compute_split_shapes(self.nlon_out, self.comm_size_azimuth)
# for upscaling the latitudes we will use interpolation
self.lats_in, _ = _precompute_latitudes(nlat_in, grid=grid_in)
self.lons_in = np.linspace(0, 2 * math.pi, nlon_in, endpoint=False)
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = np.linspace(0, 2 * math.pi, nlon_out, endpoint=False)
# in the case where some points lie outside of the range spanned by lats_in,
# we need to expand the solution to the poles before interpolating
self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
if self.expand_poles:
self.lats_in = np.insert(self.lats_in, 0, 0.0)
self.lats_in = np.append(self.lats_in, np.pi)
# prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1
# make sure that we properly treat the last point if it coincides with the pole
lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
# lat_idx = np.where(self.lats_out > self.lats_in[-1], lat_idx - 1, lat_idx)
# lat_idx = np.where(self.lats_out < self.lats_in[0], 0, lat_idx)
# compute the interpolation weights along the latitude
lat_weights = torch.from_numpy((self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx]).float()
lat_weights = lat_weights.unsqueeze(-1)
# convert to tensor
lat_idx = torch.LongTensor(lat_idx)
# register buffers
self.register_buffer("lat_idx", lat_idx, persistent=False)
self.register_buffer("lat_weights", lat_weights, persistent=False)
# get left and right indices but this time make sure periodicity in the longitude is handled
lon_idx_left = np.searchsorted(self.lons_in, self.lons_out, side="right") - 1
lon_idx_right = np.where(self.lons_out >= self.lons_in[-1], np.zeros_like(lon_idx_left), lon_idx_left + 1)
# get the difference
diff = self.lons_in[lon_idx_right] - self.lons_in[lon_idx_left]
diff = np.where(diff < 0.0, diff + 2 * math.pi, diff)
lon_weights = torch.from_numpy((self.lons_out - self.lons_in[lon_idx_left]) / diff).float()
# convert to tensor
lon_idx_left = torch.LongTensor(lon_idx_left)
lon_idx_right = torch.LongTensor(lon_idx_right)
# register buffers
self.register_buffer("lon_idx_left", lon_idx_left, persistent=False)
self.register_buffer("lon_idx_right", lon_idx_right, persistent=False)
self.register_buffer("lon_weights", lon_weights, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor):
# do the interpolation
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
return x
# old deprecated method with repeat_interleave
# def _upscale_longitudes(self, x: torch.Tensor):
# # for artifact-free upsampling in the longitudinal direction
# x = torch.repeat_interleave(x, self.lon_scale_factor, dim=-1)
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x
def _expand_poles(self, x: torch.Tensor):
repeats = [1 for _ in x.shape]
repeats[-1] = x.shape[-1]
x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
x = torch.concatenate((x_north, x, x_south), dim=-2)
return x
def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
return x
def forward(self, x: torch.Tensor):
# the input arrives with h and w split across ranks and the channel dim full
num_chans = x.shape[-3]
# first make h (latitude) fully local by transposing it against the channel dim,
# which gets split instead
if self.comm_size_polar > 1:
channels_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
x = distributed_transpose_polar.apply(x, (-3, -2), self.lat_in_shapes)
# expand poles if requested
if self.expand_poles:
x = self._expand_poles(x)
# upscaling
x = self._upscale_latitudes(x)
# now, transpose back
if self.comm_size_polar > 1:
x = distributed_transpose_polar.apply(x, (-2, -3), channels_shapes)
# now, transpose in w:
if self.comm_size_azimuth > 1:
channels_shapes = compute_split_shapes(num_chans, self.comm_size_azimuth)
x = distributed_transpose_azimuth.apply(x, (-3, -1), self.lon_in_shapes)
# upscale
x = self._upscale_longitudes(x)
# transpose back
if self.comm_size_azimuth > 1:
x = distributed_transpose_azimuth.apply(x, (-1, -3), channels_shapes)
return x
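# Note (not part of the commit): in the distributed setting each rank feeds forward()
# its local shard of shape (B, C, lat_in_shapes[polar_rank], lon_in_shapes[azimuth_rank])
# and receives the corresponding shard of the output grid, which the test above gathers
# and checks against the single-process ResampleS2 result.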