"vscode:/vscode.git/clone" did not exist on "89bc30798f90dff8d5c8ac7014ea97a832b47674"
Commit 7286a0d6 authored by Boris Bonev's avatar Boris Bonev Committed by Boris Bonev
Browse files

some minor bugfixes

parent a2b21fb6
...@@ -393,8 +393,10 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -393,8 +393,10 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt = 1 * 3600 dt = 1 * 3600
dt_solver = 150 dt_solver = 150
nsteps = dt // dt_solver nsteps = dt // dt_solver
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(257, 512), device=device, grid="legendre-gauss", normalize=True) grid = "legendre-gauss"
dataset.sht = RealSHT(nlat=257, nlon=512, grid= "equiangular").to(device=device) nlat, nlon =(181, 360)
dataset = PdeDataset(dt=dt, nsteps=nsteps, dims=(nlat, nlon), device=device, grid=grid, normalize=True)
dataset.sht = RealSHT(nlat=nlat, nlon=nlon, grid= grid).to(device=device)
# There is still an issue with parallel dataloading. Do NOT use it at the moment # There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True) # dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False) dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, persistent_workers=False)
...@@ -412,27 +414,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -412,27 +414,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
from torch_harmonics.examples.models import SphericalFourierNeuralOperatorNet as SFNO from torch_harmonics.examples.models import SphericalFourierNeuralOperatorNet as SFNO
from torch_harmonics.examples.models import LocalSphericalNeuralOperatorNet as LSNO from torch_harmonics.examples.models import LocalSphericalNeuralOperatorNet as LSNO
models[f"sfno_sc2_layers4_e32_nomlp_leggauss"] = partial( models[f"sfno_sc2_layers4_e32"] = partial(
SFNO, SFNO,
img_size=(nlat, nlon), img_size=(nlat, nlon),
grid="legendre-gauss", grid=grid,
# hard_thresholding_fraction=0.8, hard_thresholding_fraction=0.8,
num_layers=4, num_layers=4,
scale_factor=2, scale_factor=2,
embed_dim=32, embed_dim=32,
operator_type="driscoll-healy", operator_type="driscoll-healy",
activation_function="gelu", activation_function="gelu",
big_skip=False, big_skip=True,
pos_embed=False, pos_embed=False,
use_mlp=False, use_mlp=True,
normalization_layer="none", normalization_layer="none",
) )
models[f"lsno_sc1_layers4_e32_nomlp"] = partial( models[f"lsno_sc2_layers4_e32"] = partial(
LSNO, LSNO,
spectral_transform="sht", spectral_transform="sht",
img_size=(nlat, nlon), img_size=(nlat, nlon),
grid="equiangular", grid=grid,
num_layers=4, num_layers=4,
scale_factor=2, scale_factor=2,
embed_dim=32, embed_dim=32,
...@@ -440,7 +442,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0): ...@@ -440,7 +442,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
activation_function="gelu", activation_function="gelu",
big_skip=True, big_skip=True,
pos_embed=False, pos_embed=False,
use_mlp=False, use_mlp=True,
normalization_layer="none", normalization_layer="none",
) )
......
...@@ -34,7 +34,7 @@ import torch.nn as nn ...@@ -34,7 +34,7 @@ import torch.nn as nn
from torch_harmonics import * from torch_harmonics import *
from .layers import * from ._layers import *
from functools import partial from functools import partial
......
...@@ -33,7 +33,7 @@ import torch ...@@ -33,7 +33,7 @@ import torch
from math import ceil from math import ceil
from ...shallow_water_equations import ShallowWaterSolver from .shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset): class PdeDataset(torch.utils.data.Dataset):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment