Commit 7286a0d6 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

some minor bugfixes

parent a2b21fb6
......@@ -393,8 +393,10 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt = 1 * 3600
dt_solver = 150
nsteps = dt // dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, grid="legendre-gauss", normalize=True)
dataset.sht = RealSHT(nlat=257, nlon=512, grid= "equiangular").to(device=device)
grid = "legendre-gauss"
nlat, nlon =(181, 360)
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(nlat, nlon), device=device, grid=grid, normalize=True)
dataset.sht = RealSHT(nlat=nlat, nlon=nlon, grid= grid).to(device=device)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
......@@ -412,27 +414,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
from torch_harmonics.examples.models import SphericalFourierNeuralOperatorNet as SFNO
from torch_harmonics.examples.models import LocalSphericalNeuralOperatorNet as LSNO
models[f"sfno_sc2_layers4_e32_nomlp_leggauss"] = partial(
models[f"sfno_sc2_layers4_e32"] = partial(
SFNO,
img_size=(nlat, nlon),
grid="legendre-gauss",
# hard_thresholding_fraction=0.8,
grid=grid,
hard_thresholding_fraction=0.8,
num_layers=4,
scale_factor=2,
embed_dim=32,
operator_type="driscoll-healy",
activation_function="gelu",
big_skip=False,
big_skip=True,
pos_embed=False,
use_mlp=False,
use_mlp=True,
normalization_layer="none",
)
models[f"lsno_sc1_layers4_e32_nomlp"] = partial(
models[f"lsno_sc2_layers4_e32"] = partial(
LSNO,
spectral_transform="sht",
img_size=(nlat, nlon),
grid="equiangular",
grid=grid,
num_layers=4,
scale_factor=2,
embed_dim=32,
......@@ -440,7 +442,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
activation_function="gelu",
big_skip=True,
pos_embed=False,
use_mlp=False,
use_mlp=True,
normalization_layer="none",
)
......
......@@ -34,7 +34,7 @@ import torch.nn as nn
from torch_harmonics import *
from .layers import *
from ._layers import *
from functools import partial
......
......@@ -33,7 +33,7 @@ import torch
from math import ceil
from ...shallow_water_equations import ShallowWaterSolver
from .shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment