Commit 55bbcb25 authored by Thorsten Kurth, committed by Boris Bonev

implemented slerp

parent 87d9bfdc
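For context: "slerp" (spherical linear interpolation) replaces the linear weights (1 - t) and t used by torch.lerp with sin((1 - t)·ω)/sin(ω) and sin(t·ω)/sin(ω), where ω is the difference between the two endpoint values treated as an angle, and falls back to plain linear interpolation when sin(ω) is too small to divide by. Below is a minimal standalone sketch of that weighting; the function name slerp and the generic weight tensor t are ours, while the 1e-4 threshold and the torch.where fallback mirror the new code in the diff.

import torch

def slerp(left: torch.Tensor, right: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # angular separation between the two endpoint values
    omega = right - left
    somega = torch.sin(omega)
    # sin-based weights; where sin(omega) is tiny, fall back to the ordinary lerp weights
    start_prefac = torch.where(somega > 1.0e-4, torch.sin((1.0 - t) * omega) / somega, 1.0 - t)
    end_prefac = torch.where(somega > 1.0e-4, torch.sin(t * omega) / somega, t)
    return start_prefac * left + end_prefac * right

# worked check: halfway between 0 and 1 the sin-based weights give ~0.5697,
# whereas torch.lerp would give exactly 0.5
a = torch.tensor([0.0])
b = torch.tensor([1.0])
print(slerp(a, b, torch.tensor(0.5)))  # tensor([0.5697]) approximately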
(Part of this commit's diff is collapsed and not shown below.)
@@ -183,12 +183,14 @@ class TestDistributedResampling(unittest.TestCase):
     @parameterized.expand(
         [
-            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", 1e-7],
-            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", 1e-7],
+            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
+            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
+            [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
+            [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
         ]
     )
     def test_distributed_resampling(
-        self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, tol
+        self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol
     ):
         B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
@@ -200,6 +202,7 @@ class TestDistributedResampling(unittest.TestCase):
             nlon_out=nlon_out,
             grid_in=grid_in,
             grid_out=grid_out,
+            mode=mode,
         )
         # set up handles
@@ -33,7 +33,7 @@ __version__ = "0.7.4a"
 from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
-from .resampling import ResampleS2
+from .resample import ResampleS2
 from . import quadrature
 from . import random_fields
 from . import examples
@@ -57,7 +57,7 @@ class DistributedResampleS2(nn.Module):
         super().__init__()

         # currently only bilinear is supported
-        if mode == "bilinear":
+        if mode in ["bilinear", "bilinear-spherical"]:
             self.mode = mode
         else:
             raise NotImplementedError(f"unknown interpolation mode {mode}")
@@ -138,7 +138,15 @@ class DistributedResampleS2(nn.Module):
     def _upscale_longitudes(self, x: torch.Tensor):
         # do the interpolation
-        x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
+        if self.mode == "bilinear":
+            x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
+        else:
+            omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
+            somega = torch.sin(omega)
+            start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
+            end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
+            x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
         return x

     # old deprecated method with repeat_interleave
@@ -158,7 +166,15 @@ class DistributedResampleS2(nn.Module):
     def _upscale_latitudes(self, x: torch.Tensor):
         # do the interpolation
-        x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
+        if self.mode == "bilinear":
+            x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
+        else:
+            omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
+            somega = torch.sin(omega)
+            start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
+            end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
+            x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
         return x

     def forward(self, x: torch.Tensor):
@@ -54,7 +54,7 @@ class ResampleS2(nn.Module):
         super().__init__()

         # currently only bilinear is supported
-        if mode == "bilinear":
+        if mode in ["bilinear", "bilinear-spherical"]:
             self.mode = mode
         else:
             raise NotImplementedError(f"unknown interpolation mode {mode}")
@@ -123,7 +123,15 @@ class ResampleS2(nn.Module):
     def _upscale_longitudes(self, x: torch.Tensor):
         # do the interpolation
-        x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
+        if self.mode == "bilinear":
+            x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
+        else:
+            omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
+            somega = torch.sin(omega)
+            start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
+            end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
+            x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
         return x

     # old deprecated method with repeat_interleave
@@ -143,7 +151,15 @@ class ResampleS2(nn.Module):
     def _upscale_latitudes(self, x: torch.Tensor):
         # do the interpolation
-        x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
+        if self.mode == "bilinear":
+            x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
+        else:
+            omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
+            somega = torch.sin(omega)
+            start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
+            end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
+            x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
         return x

     def forward(self, x: torch.Tensor):
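For completeness, a hedged usage sketch of the new mode. The constructor keywords and grid sizes are taken from the first "bilinear-spherical" test case above; the package import path (torch_harmonics) is inferred from the module names in the __init__.py hunk, and the exact ResampleS2 signature may differ slightly from what the truncated diff shows.

import torch
from torch_harmonics import ResampleS2

# resample from a 64 x 128 equiangular grid to a 128 x 256 equiangular grid
# using the newly added spherical interpolation mode
resample = ResampleS2(
    nlat_in=64,
    nlon_in=128,
    nlat_out=128,
    nlon_out=256,
    grid_in="equiangular",
    grid_out="equiangular",
    mode="bilinear-spherical",
)

x = torch.randn(32, 8, 64, 128)  # (batch, channels, nlat_in, nlon_in), as in the test
y = resample(x)                  # expected shape: (32, 8, 128, 256)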