Unverified commit 318fc76e authored by Thorsten Kurth, committed by GitHub

fixing distributed resampling routine (#74)

parent 18f2c1cc
@@ -195,6 +195,10 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
             [64, 128, 128, 256, 32, 8, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
             [128, 256, 128, 256, 32, 8, (3), "piecewise linear", "mean", 2, "equiangular", "equiangular", True, 1e-5],
             [128, 256, 128, 256, 32, 6, (3), "piecewise linear", "mean", 1, "equiangular", "equiangular", True, 1e-5],
+            [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5],
+            [129, 256, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5],
+            [65, 128, 129, 256, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", True, 1e-5],
+            [129, 256, 65, 128, 32, 8, (3, 4), "morlet", "mean", 1, "equiangular", "equiangular", False, 1e-5],
         ]
     )
     def test_distributed_disco_conv(
@@ -187,6 +187,10 @@ class TestDistributedResampling(unittest.TestCase):
             [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
             [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
             [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7, False],
+            [129, 256, 65, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
+            [65, 128, 129, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7, False],
+            [129, 256, 65, 128, 32, 8, "equiangular", "legendre-gauss", "bilinear", 1e-7, False],
+            [65, 128, 129, 256, 32, 8, "legendre-gauss", "equiangular", "bilinear", 1e-7, False],
         ]
     )
     def test_distributed_resampling(
@@ -248,7 +252,7 @@ class TestDistributedResampling(unittest.TestCase):
         with torch.no_grad():
             out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)
             err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
-            if verbose and (self.world_rank == )0:
+            if verbose and (self.world_rank == 0):
                 print(f"final relative error of output: {err.item()}")
 
         self.assertTrue(err.item() <= tol)
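The acceptance criterion in the tests above is a batched relative Frobenius error between the single-rank reference output and the gathered distributed output. A minimal standalone sketch of the same check (tensor shapes and names are illustrative):

```python
import torch

def relative_frobenius_error(out_full: torch.Tensor, out_gathered: torch.Tensor) -> float:
    # Frobenius norm over the two spatial dims, averaged over batch/channel dims
    num = torch.norm(out_full - out_gathered, p="fro", dim=(-1, -2))
    den = torch.norm(out_full, p="fro", dim=(-1, -2))
    return torch.mean(num / den).item()

# two nearly identical fields pass the 1e-7 tolerance used above
a = torch.randn(2, 3, 65, 128, dtype=torch.float64)
b = a + 1e-9 * torch.randn_like(a)
assert relative_frobenius_error(a, b) <= 1e-7
```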
@@ -37,9 +37,11 @@ from .primitives import (
     distributed_transpose_azimuth,
     distributed_transpose_polar,
     reduce_from_polar_region,
+    reduce_from_azimuth_region,
     scatter_to_polar_region,
     gather_from_polar_region,
     copy_to_polar_region,
+    copy_to_azimuth_region,
     reduce_from_scatter_to_polar_region,
     gather_from_copy_to_polar_region
 )
@@ -37,6 +37,7 @@ import torch.nn as nn
 
 from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longitudes
 from torch_harmonics.distributed import polar_group_size, azimuth_group_size, distributed_transpose_azimuth, distributed_transpose_polar
+from torch_harmonics.distributed import reduce_from_azimuth_region, copy_to_azimuth_region
 from torch_harmonics.distributed import polar_group_rank, azimuth_group_rank
 from torch_harmonics.distributed import compute_split_shapes
@@ -92,8 +93,6 @@ class DistributedResampleS2(nn.Module):
             self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
                                       self.lats_in,
                                       torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
-            #self.lats_in = np.insert(self.lats_in, 0, 0.0)
-            #self.lats_in = np.append(self.lats_in, np.pi)
 
         # prepare the interpolation by computing indices to the left and right of each output latitude
         lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
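The `torch.searchsorted(..., side="right") - 1` idiom above gives, for each output latitude, the index of the closest input latitude at or below it; a normalized offset then fixes the interpolation weight. A self-contained sketch of that precomputation (helper and variable names are illustrative, not the module's internals):

```python
import torch

def precompute_lat_interp(lats_in: torch.Tensor, lats_out: torch.Tensor):
    # index of the input latitude immediately left of (or equal to) each output latitude
    idx = (torch.searchsorted(lats_in, lats_out, side="right") - 1).clamp(0, lats_in.shape[0] - 2)
    # normalized distance in [0, 1] to the bracketing input latitudes
    wgt = (lats_out - lats_in[idx]) / (lats_in[idx + 1] - lats_in[idx])
    return idx, wgt

lats_in = torch.linspace(0.0, torch.pi, 65, dtype=torch.float64)
lats_out = torch.linspace(0.0, torch.pi, 129, dtype=torch.float64)
idx, wgt = precompute_lat_interp(lats_in, lats_out)
vals = torch.sin(lats_in)  # a smooth test field sampled on the input grid
interp = torch.lerp(vals[idx], vals[idx + 1], wgt)
assert torch.allclose(interp, torch.sin(lats_out), atol=1e-3)
```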
@@ -135,34 +134,50 @@ class DistributedResampleS2(nn.Module):
     def _upscale_longitudes(self, x: torch.Tensor):
         # do the interpolation
+        lwgt = self.lon_weights.to(x.dtype)
         if self.mode == "bilinear":
-            x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
+            x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], lwgt)
         else:
             omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
             somega = torch.sin(omega)
-            start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
-            end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
+            start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
+            end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
             x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
 
         return x
 
     def _expand_poles(self, x: torch.Tensor):
-        repeats = [1 for _ in x.shape]
-        repeats[-1] = x.shape[-1]
-        x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
-        x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
-        x = torch.concatenate((x_north, x, x_south), dim=-2)
+        x_north = x[..., 0, :].sum(dim=-1, keepdims=True)
+        x_south = x[..., -1, :].sum(dim=-1, keepdims=True)
+        x_count = torch.tensor([x.shape[-1]], dtype=torch.long, device=x.device, requires_grad=False)
+
+        if self.comm_size_azimuth > 1:
+            x_north = reduce_from_azimuth_region(x_north.contiguous())
+            x_south = reduce_from_azimuth_region(x_south.contiguous())
+            x_count = reduce_from_azimuth_region(x_count)
+
+        x_north = x_north / x_count
+        x_south = x_south / x_count
+
+        if self.comm_size_azimuth > 1:
+            x_north = copy_to_azimuth_region(x_north)
+            x_south = copy_to_azimuth_region(x_south)
+
+        x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode='constant')
+        x[..., 0, :] = x_north[...]
+        x[..., -1, :] = x_south[...]
 
         return x
 
     def _upscale_latitudes(self, x: torch.Tensor):
         # do the interpolation
+        lwgt = self.lat_weights.to(x.dtype)
         if self.mode == "bilinear":
-            x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
+            x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], lwgt)
         else:
             omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
             somega = torch.sin(omega)
-            start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
-            end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
+            start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
+            end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
             x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
 
         return x
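The rewritten distributed `_expand_poles` above is the core of this fix: each rank holds only a slice of the longitude ring, so the pole value must be the mean over the full ring. All-reducing local sums together with local point counts gives that mean exactly, even when the longitudes split unevenly across ranks (e.g. 129 points over 2 ranks); averaging per-rank means would weight the smaller shard too heavily. A single-process sketch of that reasoning, with plain tensors standing in for the azimuth all-reduce:

```python
import torch

# a full latitude ring with 129 points, split unevenly across two "ranks"
ring = torch.arange(129, dtype=torch.float64)
shards = torch.split(ring, [65, 64])

# correct: all-reduce local sums and local counts, then divide (as the fix does)
exact_mean = sum(s.sum() for s in shards) / sum(s.numel() for s in shards)

# incorrect: averaging per-rank means over-weights the smaller shard
naive_mean = sum(s.mean() for s in shards) / len(shards)

assert torch.allclose(exact_mean, ring.mean())      # 64.0
assert not torch.allclose(naive_mean, ring.mean())  # 64.25 != 64.0
```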
@@ -174,7 +189,7 @@ class DistributedResampleS2(nn.Module):
 
         # transpose data so that h is local, and channels are split
         num_chans = x.shape[-3]
 
         # h and w is split. First we make w local by transposing into channel dim
         if self.comm_size_polar > 1:
             channels_shapes = compute_split_shapes(num_chans, self.comm_size_polar)
@@ -35,7 +35,7 @@ import torch.distributed as dist
 from torch.amp import custom_fwd, custom_bwd
 
 from .utils import polar_group, azimuth_group, polar_group_size
-from .utils import is_initialized, is_distributed_polar
+from .utils import is_initialized, is_distributed_polar, is_distributed_azimuth
 
 # helper routine to compute uneven splitting in balanced way:
 def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
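The odd grid sizes exercised by the new tests (129, 65) rely on `compute_split_shapes` spreading the remainder across ranks. The actual implementation is in the source; here is merely a plausible sketch of such a balanced split, for illustration:

```python
from typing import List

def balanced_split_shapes(size: int, num_chunks: int) -> List[int]:
    # distribute size items over num_chunks so chunk sizes differ by at most one
    base, rem = divmod(size, num_chunks)
    return [base + 1 if i < rem else base for i in range(num_chunks)]

# 129 latitude rows over 8 ranks: the first rank takes the extra row
assert balanced_split_shapes(129, 8) == [17, 16, 16, 16, 16, 16, 16, 16]
assert sum(balanced_split_shapes(65, 4)) == 65
```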
@@ -262,8 +262,29 @@ class _CopyToPolarRegion(torch.autograd.Function):
             return _reduce(grad_output, group=polar_group())
         else:
             return grad_output, None
 
 
+class _CopyToAzimuthRegion(torch.autograd.Function):
+    """Copy the input to the azimuth region; the backward pass all-reduces the gradients."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(ctx, input_):
+        return input_
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, grad_output):
+        if is_distributed_azimuth():
+            return _reduce(grad_output, group=azimuth_group())
+        else:
+            return grad_output, None
+
+
 class _ScatterToPolarRegion(torch.autograd.Function):
     """Split the input and keep only the corresponding chunk to the rank."""
@@ -340,6 +361,30 @@ class _ReduceFromPolarRegion(torch.autograd.Function):
     def backward(ctx, grad_output):
         return grad_output
 
 
+class _ReduceFromAzimuthRegion(torch.autograd.Function):
+    """All-reduce the input from the azimuth region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        if is_distributed_azimuth():
+            return _reduce(input_, group=azimuth_group())
+        else:
+            return input_
+
+    @staticmethod
+    @custom_fwd(device_type="cuda")
+    def forward(ctx, input_):
+        if is_distributed_azimuth():
+            return _reduce(input_, group=azimuth_group())
+        else:
+            return input_
+
+    @staticmethod
+    @custom_bwd(device_type="cuda")
+    def backward(ctx, grad_output):
+        return grad_output
+
+
 class _ReduceFromScatterToPolarRegion(torch.autograd.Function):
     """All-reduce the input from the polar region and scatter back to polar region."""
@@ -403,23 +448,24 @@ class _GatherFromCopyToPolarRegion(torch.autograd.Function):
 
 def copy_to_polar_region(input_):
     return _CopyToPolarRegion.apply(input_)
 
 
+def copy_to_azimuth_region(input_):
+    return _CopyToAzimuthRegion.apply(input_)
+
+
 def reduce_from_polar_region(input_):
     return _ReduceFromPolarRegion.apply(input_)
 
 
+def reduce_from_azimuth_region(input_):
+    return _ReduceFromAzimuthRegion.apply(input_)
+
+
 def scatter_to_polar_region(input_, dim_):
     return _ScatterToPolarRegion.apply(input_, dim_)
 
 
 def gather_from_polar_region(input_, dim_, shapes_):
     return _GatherFromPolarRegion.apply(input_, dim_, shapes_)
 
 
 def reduce_from_scatter_to_polar_region(input_, dim_):
     return _ReduceFromScatterToPolarRegion.apply(input_, dim_)
 
 
 def gather_from_copy_to_polar_region(input_, dim_, shapes_):
     return _GatherFromCopyToPolarRegion.apply(input_, dim_, shapes_)
@@ -78,8 +78,6 @@ class ResampleS2(nn.Module):
             self.lats_in = torch.cat([torch.tensor([0.], dtype=torch.float64),
                                       self.lats_in,
                                       torch.tensor([math.pi], dtype=torch.float64)]).contiguous()
-            #self.lats_in = np.insert(self.lats_in, 0, 0.0)
-            #self.lats_in = np.append(self.lats_in, np.pi)
 
         # prepare the interpolation by computing indices to the left and right of each output latitude
         lat_idx = torch.searchsorted(self.lats_in, self.lats_out, side="right") - 1
@@ -135,11 +133,12 @@ class ResampleS2(nn.Module):
 
         return x
 
     def _expand_poles(self, x: torch.Tensor):
-        repeats = [1 for _ in x.shape]
-        repeats[-1] = x.shape[-1]
-        x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
-        x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
-        x = torch.concatenate((x_north, x, x_south), dim=-2).contiguous()
+        x_north = x[..., 0, :].mean(dim=-1, keepdims=True)
+        x_south = x[..., -1, :].mean(dim=-1, keepdims=True)
+        x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode='constant')
+        x[..., 0, :] = x_north[...]
+        x[..., -1, :] = x_south[...]
 
         return x
 
     def _upscale_latitudes(self, x: torch.Tensor):
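In the single-process `ResampleS2`, the pad-and-assign version above should be numerically identical to the old mean/repeat/concatenate version while avoiding the repeated intermediate tensors; a quick equivalence check under that assumption:

```python
import torch
import torch.nn as nn

def expand_poles_old(x: torch.Tensor) -> torch.Tensor:
    repeats = [1 for _ in x.shape]
    repeats[-1] = x.shape[-1]
    x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
    x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
    return torch.concatenate((x_north, x, x_south), dim=-2)

def expand_poles_new(x: torch.Tensor) -> torch.Tensor:
    x_north = x[..., 0, :].mean(dim=-1, keepdims=True)
    x_south = x[..., -1, :].mean(dim=-1, keepdims=True)
    x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode="constant")
    x[..., 0, :] = x_north[...]
    x[..., -1, :] = x_south[...]
    return x

x = torch.randn(2, 3, 64, 128)
assert torch.allclose(expand_poles_old(x), expand_poles_new(x))
```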
@@ -162,6 +161,9 @@ class ResampleS2(nn.Module):
 
+        if self.expand_poles:
+            x = self._expand_poles(x)
+
         x = self._upscale_latitudes(x)
         x = self._upscale_longitudes(x)
 
         return x
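Tying the pieces together, a hypothetical end-to-end use of the resampler on the odd-sized grids these fixes target. The constructor argument names below are assumptions inferred from the test parameterization, not confirmed API:

```python
import torch
from torch_harmonics import ResampleS2  # import path assumed

# upsample a scalar field from a 65 x 128 to a 129 x 256 equiangular grid
resample = ResampleS2(
    nlat_in=65, nlon_in=128, nlat_out=129, nlon_out=256,
    grid_in="equiangular", grid_out="equiangular", mode="bilinear",
)
x = torch.randn(1, 1, 65, 128)
y = resample(x)
print(y.shape)  # expected: torch.Size([1, 1, 129, 256])
```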