Commit 3350099a authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

adding option in LSNO to select between upsampling methods

parent b6b2bce3
...@@ -445,7 +445,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -445,7 +445,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer="none", normalization_layer="none",
kernel_shape=[4, 4], kernel_shape=[4, 4],
encoder_kernel_shape=[4, 4], encoder_kernel_shape=[4, 4],
filter_basis_type="morlet" filter_basis_type="morlet",
upsample_sht = True,
) )
models[f"lsno_sc2_layers4_e32_zernike"] = partial( models[f"lsno_sc2_layers4_e32_zernike"] = partial(
...@@ -463,7 +464,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -463,7 +464,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer="none", normalization_layer="none",
kernel_shape=[4], kernel_shape=[4],
encoder_kernel_shape=[4], encoder_kernel_shape=[4],
filter_basis_type="zernike" filter_basis_type="zernike",
upsample_sht = True,
) )
# iterate over models and train each model # iterate over models and train each model
......
...@@ -70,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module): ...@@ -70,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out=grid_out, grid_out=grid_out,
groups=groups, groups=groups,
bias=bias, bias=bias,
theta_cutoff=1.0 * torch.pi / float(out_shape[0] - 1), theta_cutoff=4.0 * torch.pi / float(out_shape[0] - 1),
) )
def forward(self, x): def forward(self, x):
...@@ -97,13 +97,17 @@ class DiscreteContinuousDecoder(nn.Module): ...@@ -97,13 +97,17 @@ class DiscreteContinuousDecoder(nn.Module):
basis_type="piecewise linear", basis_type="piecewise linear",
groups=1, groups=1,
bias=False, bias=False,
upsample_sht=False
): ):
super().__init__() super().__init__()
# # set up # set up upsampling
if upsample_sht:
self.sht = RealSHT(*in_shape, grid=grid_in).float() self.sht = RealSHT(*in_shape, grid=grid_in).float()
self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float() self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
self.upscale = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out) self.upsample = nn.Sequential(self.sht, self.isht)
else:
self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
# set up DISCO convolution # set up DISCO convolution
self.conv = DiscreteContinuousConvS2( self.conv = DiscreteContinuousConvS2(
...@@ -117,19 +121,15 @@ class DiscreteContinuousDecoder(nn.Module): ...@@ -117,19 +121,15 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out=grid_out, grid_out=grid_out,
groups=groups, groups=groups,
bias=False, bias=False,
theta_cutoff=1.0 * torch.pi / float(in_shape[0] - 1), theta_cutoff=4.0 * torch.pi / float(in_shape[0] - 1),
) )
def upscale_sht(self, x: torch.Tensor):
return self.isht(self.sht(x))
def forward(self, x): def forward(self, x):
dtype = x.dtype dtype = x.dtype
x = self.upscale(x)
with amp.autocast(device_type="cuda", enabled=False): with amp.autocast(device_type="cuda", enabled=False):
x = x.float() x = x.float()
# x = self.upscale_sht(x) x = self.upsample(x)
x = self.conv(x) x = self.conv(x)
x = x.to(dtype=dtype) x = x.to(dtype=dtype)
...@@ -182,7 +182,7 @@ class SphericalNeuralOperatorBlock(nn.Module): ...@@ -182,7 +182,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
grid_in=forward_transform.grid, grid_in=forward_transform.grid,
grid_out=inverse_transform.grid, grid_out=inverse_transform.grid,
bias=False, bias=False,
theta_cutoff=1.0 * (disco_kernel_shape[0] + 1) * torch.pi / float(inverse_transform.nlat - 1), theta_cutoff=4.0 * (disco_kernel_shape[0] + 1) * torch.pi / float(inverse_transform.nlat - 1),
) )
elif conv_type == "global": elif conv_type == "global":
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False) self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
...@@ -309,6 +309,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -309,6 +309,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Whether to add a single large skip connection, by default True Whether to add a single large skip connection, by default True
pos_embed : bool, optional pos_embed : bool, optional
Whether to use positional embedding, by default True Whether to use positional embedding, by default True
upsample_sht : bool, optional
Use SHT upsampling if true, else linear interpolation
Example Example
----------- -----------
...@@ -359,6 +361,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -359,6 +361,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
use_complex_kernels=True, use_complex_kernels=True,
big_skip=False, big_skip=False,
pos_embed=False, pos_embed=False,
upsample_sht=False,
): ):
super().__init__() super().__init__()
...@@ -491,6 +494,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -491,6 +494,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
basis_type=filter_basis_type, basis_type=filter_basis_type,
groups=1, groups=1,
bias=False, bias=False,
upsample_sht=upsample_sht
) )
# # residual prediction # # residual prediction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment