"src/diffusers/pipelines/ddpm/pipeline_ddpm.py" did not exist on "dd4cd081db39d6769060bb48d0137b832789f015"
Commit 3350099a authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

adding option in LSNO to select between upsampling methods

parent b6b2bce3
......@@ -445,7 +445,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer="none",
kernel_shape=[4, 4],
encoder_kernel_shape=[4, 4],
filter_basis_type="morlet"
filter_basis_type="morlet",
upsample_sht = True,
)
models[f"lsno_sc2_layers4_e32_zernike"] = partial(
......@@ -463,7 +464,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer="none",
kernel_shape=[4],
encoder_kernel_shape=[4],
filter_basis_type="zernike"
filter_basis_type="zernike",
upsample_sht = True,
)
# iterate over models and train each model
......
......@@ -70,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=bias,
theta_cutoff=1.0 * torch.pi / float(out_shape[0] - 1),
theta_cutoff=4.0 * torch.pi / float(out_shape[0] - 1),
)
def forward(self, x):
......@@ -97,13 +97,17 @@ class DiscreteContinuousDecoder(nn.Module):
basis_type="piecewise linear",
groups=1,
bias=False,
upsample_sht=False
):
super().__init__()
# # set up
# set up upsampling
if upsample_sht:
self.sht = RealSHT(*in_shape, grid=grid_in).float()
self.isht = InverseRealSHT(*out_shape, lmax=self.sht.lmax, mmax=self.sht.mmax, grid=grid_out).float()
self.upscale = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
self.upsample = nn.Sequential(self.sht, self.isht)
else:
self.upsample = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
# set up DISCO convolution
self.conv = DiscreteContinuousConvS2(
......@@ -117,19 +121,15 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out=grid_out,
groups=groups,
bias=False,
theta_cutoff=1.0 * torch.pi / float(in_shape[0] - 1),
theta_cutoff=4.0 * torch.pi / float(in_shape[0] - 1),
)
def upscale_sht(self, x: torch.Tensor):
return self.isht(self.sht(x))
def forward(self, x):
dtype = x.dtype
x = self.upscale(x)
with amp.autocast(device_type="cuda", enabled=False):
x = x.float()
# x = self.upscale_sht(x)
x = self.upsample(x)
x = self.conv(x)
x = x.to(dtype=dtype)
......@@ -182,7 +182,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
grid_in=forward_transform.grid,
grid_out=inverse_transform.grid,
bias=False,
theta_cutoff=1.0 * (disco_kernel_shape[0] + 1) * torch.pi / float(inverse_transform.nlat - 1),
theta_cutoff=4.0 * (disco_kernel_shape[0] + 1) * torch.pi / float(inverse_transform.nlat - 1),
)
elif conv_type == "global":
self.global_conv = SpectralConvS2(forward_transform, inverse_transform, input_dim, output_dim, gain=gain_factor, operator_type=operator_type, bias=False)
......@@ -309,6 +309,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Whether to add a single large skip connection, by default True
pos_embed : bool, optional
Whether to use positional embedding, by default True
upsample_sht : bool, optional
Use SHT upsampling if true, else linear interpolation
Example
-----------
......@@ -359,6 +361,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
use_complex_kernels=True,
big_skip=False,
pos_embed=False,
upsample_sht=False,
):
super().__init__()
......@@ -491,6 +494,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
basis_type=filter_basis_type,
groups=1,
bias=False,
upsample_sht=upsample_sht
)
# # residual prediction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment