Commit 3f125603 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

fixing sclae factor logic to assume that the poles are included in the latitiude grid

parent 4369d182
......@@ -373,7 +373,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt = 1 * 3600
dt_solver = 150
nsteps = dt // dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True)
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, normalize=True)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
......
......@@ -454,8 +454,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
else:
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size
self.h = self.img_size[0] // scale_factor
# compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor
self.w = self.img_size[1] // scale_factor
# dropout
......
......@@ -325,8 +325,8 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
else:
raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size
self.h = self.img_size[0] // scale_factor
# compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = (self.img_size[0] - 1) // scale_factor
self.w = self.img_size[1] // scale_factor
# dropout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment