Commit 96a2b546 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

updated resampling module to handle corner cases where the input signal needs...

updated resampling module to handle corner cases where the input signal needs to be extended towards the pole
parent 856a0f18
......@@ -11,7 +11,9 @@
* New filter basis normalization in DISCO convolutions
* Reworked DISCO filter basis datastructure
* Support for new filter basis types
* Adding Morlet-like basis functions on a spherical disk
* Adding Morlet wavelet basis functions on a spherical disk
* Cleaning up the SFNO example and adding new Local Spherical Neural Operator model
* Updated resampling module to extend input signal to the poles if needed
### v0.7.3
......
This diff is collapsed.
......@@ -440,19 +440,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1),
)
# # encoder
# self.encoder = DiscreteContinuousEncoder(
# inp_shape=self.img_size,
# out_shape=(self.h, self.w),
# grid_in=grid,
# grid_out=grid_internal,
# inp_chans=self.in_chans,
# out_chans=self.embed_dim,
# kernel_shape=self.encoder_kernel_shape,
# groups=1,
# bias=False,
# )
# prepare the spectral transform
if self.spectral_transform == "sht":
modes_lat = int(self.h * self.hard_thresholding_fraction)
......@@ -501,20 +488,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self.blocks.append(block)
# # decoder
# self.decoder = DiscreteContinuousConvTransposeS2(
# self.embed_dim,
# self.out_chans,
# (self.h, self.w),
# self.img_size,
# self.encoder_kernel_shape,
# groups=1,
# grid_in="legendre-gauss",
# grid_out=grid,
# bias=False,
# theta_cutoff=math.sqrt(2) * torch.pi / float(self.h - 1),
# )
# decoder
self.decoder = DiscreteContinuousDecoder(
inp_shape=(self.h, self.w),
......
......@@ -71,11 +71,21 @@ class ResampleS2(nn.Module):
self.lats_out, _ = _precompute_latitudes(nlat_out, grid=grid_out)
self.lons_out = np.linspace(0, 2 * math.pi, nlon_out, endpoint=False)
# in the case where some points lie outside of the range spanned by lats_in,
# we need to expand the solution to the poles before interpolating
self.expand_poles = (self.lats_out > self.lats_in[-1]).any() or (self.lats_out < self.lats_in[0]).any()
if self.expand_poles:
self.lats_in = np.insert(self.lats_in, 0, 0.0)
self.lats_in = np.append(self.lats_in, np.pi)
# prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx = np.searchsorted(self.lats_in, self.lats_out, side="right") - 1
# to guarantee everything stays in bounds
# make sure that we properly treat the last point if they coincide with the pole
lat_idx = np.where(self.lats_out == self.lats_in[-1], lat_idx - 1, lat_idx)
# lat_idx = np.where(self.lats_out > self.lats_in[-1], lat_idx - 1, lat_idx)
# lat_idx = np.where(self.lats_out < self.lats_in[0], 0, lat_idx)
# compute the interpolation weights along the latitude
lat_weights = torch.from_numpy((self.lats_out - self.lats_in[lat_idx]) / np.diff(self.lats_in)[lat_idx]).float()
lat_weights = lat_weights.unsqueeze(-1)
......@@ -123,12 +133,22 @@ class ResampleS2(nn.Module):
# x = torch.roll(x, - self.lon_shift, dims=-1)
# return x
def _expand_poles(self, x: torch.Tensor):
repeats = [1 for _ in x.shape]
repeats[-1] = x.shape[-1]
x_north = x[..., 0:1, :].mean(dim=-1, keepdim=True).repeat(*repeats)
x_south = x[..., -1:, :].mean(dim=-1, keepdim=True).repeat(*repeats)
x = torch.concatenate((x_north, x, x_south), dim=-2)
return x
def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
return x
def forward(self, x: torch.Tensor):
if self.expand_poles:
x = self._expand_poles(x)
x = self._upscale_latitudes(x)
x = self._upscale_longitudes(x)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment