Commit 3f125603 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

fixing sclae factor logic to assume that the poles are included in the latitiude grid

parent 4369d182
...@@ -373,7 +373,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -373,7 +373,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt = 1 * 3600 dt = 1 * 3600
dt_solver = 150 dt_solver = 150
nsteps = dt // dt_solver nsteps = dt // dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(256, 512), device=device, normalize=True) dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, normalize=True)
# There is still an issue with parallel dataloading. Do NOT use it at the moment # There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True) # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False) dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
......
...@@ -454,8 +454,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module): ...@@ -454,8 +454,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
else: else:
raise ValueError(f"Unknown activation function {activation_function}") raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size # compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = self.img_size[0] // scale_factor self.h = (self.img_size[0] - 1) // scale_factor
self.w = self.img_size[1] // scale_factor self.w = self.img_size[1] // scale_factor
# dropout # dropout
......
...@@ -325,8 +325,8 @@ class SphericalFourierNeuralOperatorNet(nn.Module): ...@@ -325,8 +325,8 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
else: else:
raise ValueError(f"Unknown activation function {activation_function}") raise ValueError(f"Unknown activation function {activation_function}")
# compute downsampled image size # compute downsampled image size. We assume that the latitude-grid includes both poles
self.h = self.img_size[0] // scale_factor self.h = (self.img_size[0] - 1) // scale_factor
self.w = self.img_size[1] // scale_factor self.w = self.img_size[1] // scale_factor
# dropout # dropout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment