Commit 55bbcb25 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by Boris Bonev
Browse files

implemented slerp

parent 87d9bfdc
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -183,12 +183,14 @@ class TestDistributedResampling(unittest.TestCase):
@parameterized.expand(
[
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", 1e-7],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
]
)
def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, tol
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol
):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
......@@ -200,6 +202,7 @@ class TestDistributedResampling(unittest.TestCase):
nlon_out=nlon_out,
grid_in=grid_in,
grid_out=grid_out,
mode=mode,
)
# set up handlesD
......
......@@ -33,7 +33,7 @@ __version__ = "0.7.4a"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from .resampling import ResampleS2
from .resample import ResampleS2
from . import quadrature
from . import random_fields
from . import examples
......@@ -57,7 +57,7 @@ class DistributedResampleS2(nn.Module):
super().__init__()
# currently only bilinear is supported
if mode == "bilinear":
if mode in ["bilinear", "bilinear-spherical"]:
self.mode = mode
else:
raise NotImplementedError(f"unknown interpolation mode {mode}")
......@@ -138,7 +138,15 @@ class DistributedResampleS2(nn.Module):
def _upscale_longitudes(self, x: torch.Tensor):
# do the interpolation
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
if self.mode == "bilinear":
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
return x
# old deprecated method with repeat_interleave
......@@ -158,7 +166,15 @@ class DistributedResampleS2(nn.Module):
def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
if self.mode == "bilinear":
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
return x
def forward(self, x: torch.Tensor):
......
......@@ -54,7 +54,7 @@ class ResampleS2(nn.Module):
super().__init__()
# currently only bilinear is supported
if mode == "bilinear":
if mode in ["bilinear", "bilinear-spherical"]:
self.mode = mode
else:
raise NotImplementedError(f"unknown interpolation mode {mode}")
......@@ -123,7 +123,15 @@ class ResampleS2(nn.Module):
def _upscale_longitudes(self, x: torch.Tensor):
# do the interpolation
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
if self.mode == "bilinear":
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
return x
# old deprecated method with repeat_interleave
......@@ -143,7 +151,15 @@ class ResampleS2(nn.Module):
def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
if self.mode == "bilinear":
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
return x
def forward(self, x: torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment