Commit 55bbcb25 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by Boris Bonev
Browse files

implemented slerp

parent 87d9bfdc
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -183,12 +183,14 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -183,12 +183,14 @@ class TestDistributedResampling(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
[ [
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", 1e-7], [64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", 1e-7], [128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear", 1e-7],
[64, 128, 128, 256, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
[128, 256, 64, 128, 32, 8, "equiangular", "equiangular", "bilinear-spherical", 1e-7],
] ]
) )
def test_distributed_resampling( def test_distributed_resampling(
self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, tol self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol
): ):
B, C, H, W = batch_size, num_chan, nlat_in, nlon_in B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
...@@ -200,6 +202,7 @@ class TestDistributedResampling(unittest.TestCase): ...@@ -200,6 +202,7 @@ class TestDistributedResampling(unittest.TestCase):
nlon_out=nlon_out, nlon_out=nlon_out,
grid_in=grid_in, grid_in=grid_in,
grid_out=grid_out, grid_out=grid_out,
mode=mode,
) )
# set up handlesD # set up handlesD
......
...@@ -33,7 +33,7 @@ __version__ = "0.7.4a" ...@@ -33,7 +33,7 @@ __version__ = "0.7.4a"
from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
from .resampling import ResampleS2 from .resample import ResampleS2
from . import quadrature from . import quadrature
from . import random_fields from . import random_fields
from . import examples from . import examples
...@@ -57,7 +57,7 @@ class DistributedResampleS2(nn.Module): ...@@ -57,7 +57,7 @@ class DistributedResampleS2(nn.Module):
super().__init__() super().__init__()
# currently only bilinear is supported # currently only bilinear is supported
if mode == "bilinear": if mode in ["bilinear", "bilinear-spherical"]:
self.mode = mode self.mode = mode
else: else:
raise NotImplementedError(f"unknown interpolation mode {mode}") raise NotImplementedError(f"unknown interpolation mode {mode}")
...@@ -138,7 +138,15 @@ class DistributedResampleS2(nn.Module): ...@@ -138,7 +138,15 @@ class DistributedResampleS2(nn.Module):
def _upscale_longitudes(self, x: torch.Tensor): def _upscale_longitudes(self, x: torch.Tensor):
# do the interpolation # do the interpolation
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights) if self.mode == "bilinear":
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
return x return x
# old deprecated method with repeat_interleave # old deprecated method with repeat_interleave
...@@ -158,7 +166,15 @@ class DistributedResampleS2(nn.Module): ...@@ -158,7 +166,15 @@ class DistributedResampleS2(nn.Module):
def _upscale_latitudes(self, x: torch.Tensor): def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation # do the interpolation
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights) if self.mode == "bilinear":
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
return x return x
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
......
...@@ -54,7 +54,7 @@ class ResampleS2(nn.Module): ...@@ -54,7 +54,7 @@ class ResampleS2(nn.Module):
super().__init__() super().__init__()
# currently only bilinear is supported # currently only bilinear is supported
if mode == "bilinear": if mode in ["bilinear", "bilinear-spherical"]:
self.mode = mode self.mode = mode
else: else:
raise NotImplementedError(f"unknown interpolation mode {mode}") raise NotImplementedError(f"unknown interpolation mode {mode}")
...@@ -123,7 +123,15 @@ class ResampleS2(nn.Module): ...@@ -123,7 +123,15 @@ class ResampleS2(nn.Module):
def _upscale_longitudes(self, x: torch.Tensor): def _upscale_longitudes(self, x: torch.Tensor):
# do the interpolation # do the interpolation
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights) if self.mode == "bilinear":
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
else:
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lon_weights) * omega)/somega, (1.-self.lon_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lon_weights * omega)/somega, self.lon_weights)
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
return x return x
# old deprecated method with repeat_interleave # old deprecated method with repeat_interleave
...@@ -143,7 +151,15 @@ class ResampleS2(nn.Module): ...@@ -143,7 +151,15 @@ class ResampleS2(nn.Module):
def _upscale_latitudes(self, x: torch.Tensor): def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation # do the interpolation
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights) if self.mode == "bilinear":
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
else:
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
somega = torch.sin(omega)
start_prefac = torch.where(somega>1.e-4, torch.sin((1.-self.lat_weights) * omega)/somega, (1.-self.lat_weights))
end_prefac = torch.where(somega>1.e-4, torch.sin(self.lat_weights * omega)/somega, self.lat_weights)
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
return x return x
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment