Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
7286a0d6
"vscode:/vscode.git/clone" did not exist on "89bc30798f90dff8d5c8ac7014ea97a832b47674"
Commit
7286a0d6
authored
Dec 15, 2024
by
Boris Bonev
Committed by
Boris Bonev
Jan 14, 2025
Browse files
some minor bugfixes
parent
a2b21fb6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
12 deletions
+14
-12
examples/train_sfno.py
examples/train_sfno.py
+12
-10
torch_harmonics/examples/models/sfno.py
torch_harmonics/examples/models/sfno.py
+1
-1
torch_harmonics/examples/pde_dataset.py
torch_harmonics/examples/pde_dataset.py
+1
-1
No files found.
examples/train_sfno.py
View file @
7286a0d6
...
@@ -393,8 +393,10 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
...
@@ -393,8 +393,10 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
dt
=
1
*
3600
dt
=
1
*
3600
dt_solver
=
150
dt_solver
=
150
nsteps
=
dt
//
dt_solver
nsteps
=
dt
//
dt_solver
dataset
=
PdeDataset
(
dt
=
dt
,
nsteps
=
nsteps
,
dims
=
(
257
,
512
),
device
=
device
,
grid
=
"legendre-gauss"
,
normalize
=
True
)
grid
=
"legendre-gauss"
dataset
.
sht
=
RealSHT
(
nlat
=
257
,
nlon
=
512
,
grid
=
"equiangular"
).
to
(
device
=
device
)
nlat
,
nlon
=
(
181
,
360
)
dataset
=
PdeDataset
(
dt
=
dt
,
nsteps
=
nsteps
,
dims
=
(
nlat
,
nlon
),
device
=
device
,
grid
=
grid
,
normalize
=
True
)
dataset
.
sht
=
RealSHT
(
nlat
=
nlat
,
nlon
=
nlon
,
grid
=
grid
).
to
(
device
=
device
)
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# There is still an issue with parallel dataloading. Do NOT use it at the moment
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
# dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, persistent_workers=True)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
0
,
persistent_workers
=
False
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
0
,
persistent_workers
=
False
)
...
@@ -412,27 +414,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
...
@@ -412,27 +414,27 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
from
torch_harmonics.examples.models
import
SphericalFourierNeuralOperatorNet
as
SFNO
from
torch_harmonics.examples.models
import
SphericalFourierNeuralOperatorNet
as
SFNO
from
torch_harmonics.examples.models
import
LocalSphericalNeuralOperatorNet
as
LSNO
from
torch_harmonics.examples.models
import
LocalSphericalNeuralOperatorNet
as
LSNO
models
[
f
"sfno_sc2_layers4_e32
_nomlp_leggauss
"
]
=
partial
(
models
[
f
"sfno_sc2_layers4_e32"
]
=
partial
(
SFNO
,
SFNO
,
img_size
=
(
nlat
,
nlon
),
img_size
=
(
nlat
,
nlon
),
grid
=
"legendre-gauss"
,
grid
=
grid
,
#
hard_thresholding_fraction=0.8,
hard_thresholding_fraction
=
0.8
,
num_layers
=
4
,
num_layers
=
4
,
scale_factor
=
2
,
scale_factor
=
2
,
embed_dim
=
32
,
embed_dim
=
32
,
operator_type
=
"driscoll-healy"
,
operator_type
=
"driscoll-healy"
,
activation_function
=
"gelu"
,
activation_function
=
"gelu"
,
big_skip
=
Fals
e
,
big_skip
=
Tru
e
,
pos_embed
=
False
,
pos_embed
=
False
,
use_mlp
=
Fals
e
,
use_mlp
=
Tru
e
,
normalization_layer
=
"none"
,
normalization_layer
=
"none"
,
)
)
models
[
f
"lsno_sc
1
_layers4_e32
_nomlp
"
]
=
partial
(
models
[
f
"lsno_sc
2
_layers4_e32"
]
=
partial
(
LSNO
,
LSNO
,
spectral_transform
=
"sht"
,
spectral_transform
=
"sht"
,
img_size
=
(
nlat
,
nlon
),
img_size
=
(
nlat
,
nlon
),
grid
=
"equiangular"
,
grid
=
grid
,
num_layers
=
4
,
num_layers
=
4
,
scale_factor
=
2
,
scale_factor
=
2
,
embed_dim
=
32
,
embed_dim
=
32
,
...
@@ -440,7 +442,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
...
@@ -440,7 +442,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
activation_function
=
"gelu"
,
activation_function
=
"gelu"
,
big_skip
=
True
,
big_skip
=
True
,
pos_embed
=
False
,
pos_embed
=
False
,
use_mlp
=
Fals
e
,
use_mlp
=
Tru
e
,
normalization_layer
=
"none"
,
normalization_layer
=
"none"
,
)
)
...
...
torch_harmonics/examples/models/sfno.py
View file @
7286a0d6
...
@@ -34,7 +34,7 @@ import torch.nn as nn
...
@@ -34,7 +34,7 @@ import torch.nn as nn
from
torch_harmonics
import
*
from
torch_harmonics
import
*
from
.layers
import
*
from
.
_
layers
import
*
from
functools
import
partial
from
functools
import
partial
...
...
torch_harmonics/examples/pde_dataset.py
View file @
7286a0d6
...
@@ -33,7 +33,7 @@ import torch
...
@@ -33,7 +33,7 @@ import torch
from
math
import
ceil
from
math
import
ceil
from
..
.shallow_water_equations
import
ShallowWaterSolver
from
.shallow_water_equations
import
ShallowWaterSolver
class
PdeDataset
(
torch
.
utils
.
data
.
Dataset
):
class
PdeDataset
(
torch
.
utils
.
data
.
Dataset
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment