Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
d81fbd34
Commit
d81fbd34
authored
Jan 10, 2025
by
Boris Bonev
Committed by
Boris Bonev
Jan 14, 2025
Browse files
changing default normalization mode in DISCO
parent
96a2b546
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
60 additions
and
78 deletions
+60
-78
examples/train_sfno.py
examples/train_sfno.py
+15
-16
notebooks/resample_sphere.ipynb
notebooks/resample_sphere.ipynb
+12
-19
notebooks/train_sfno.ipynb
notebooks/train_sfno.ipynb
+4
-9
torch_harmonics/convolution.py
torch_harmonics/convolution.py
+4
-4
torch_harmonics/distributed/distributed_convolution.py
torch_harmonics/distributed/distributed_convolution.py
+3
-3
torch_harmonics/examples/models/lsno.py
torch_harmonics/examples/models/lsno.py
+22
-27
No files found.
examples/train_sfno.py
View file @
d81fbd34
...
...
@@ -430,21 +430,20 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer
=
"none"
,
)
# models[f"lsno_sc2_layers4_e32"] = partial(
# LSNO,
# spectral_transform="sht",
# img_size=(nlat, nlon),
# grid=grid,
# num_layers=4,
# scale_factor=2,
# embed_dim=32,
# operator_type="driscoll-healy",
# activation_function="gelu",
# big_skip=True,
# pos_embed=False,
# use_mlp=True,
# normalization_layer="none",
# )
models
[
f
"lsno_sc2_layers4_e32"
]
=
partial
(
LSNO
,
img_size
=
(
nlat
,
nlon
),
grid
=
grid
,
num_layers
=
4
,
scale_factor
=
2
,
embed_dim
=
32
,
operator_type
=
"driscoll-healy"
,
activation_function
=
"gelu"
,
big_skip
=
False
,
pos_embed
=
False
,
use_mlp
=
True
,
normalization_layer
=
"none"
,
)
# iterate over models and train each model
root_path
=
os
.
path
.
dirname
(
__file__
)
...
...
@@ -487,7 +486,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
scheduler
=
torch
.
optim
.
lr_scheduler
.
ReduceLROnPlateau
(
optimizer
,
'min'
)
gscaler
=
amp
.
GradScaler
(
enabled
=
enable_amp
)
dataloader
.
dataset
.
nsteps
=
2
*
dt
//
dt_solver
train_model
(
model
,
dataloader
,
optimizer
,
gscaler
,
scheduler
,
nepochs
=
2
0
,
loss_fn
=
"l2"
,
nfuture
=
nfuture
,
enable_amp
=
enable_amp
,
log_grads
=
log_grads
)
train_model
(
model
,
dataloader
,
optimizer
,
gscaler
,
scheduler
,
nepochs
=
1
0
,
loss_fn
=
"l2"
,
nfuture
=
nfuture
,
enable_amp
=
enable_amp
,
log_grads
=
log_grads
)
dataloader
.
dataset
.
nsteps
=
1
*
dt
//
dt_solver
training_time
=
time
.
time
()
-
start_time
...
...
notebooks/resample_sphere.ipynb
View file @
d81fbd34
This source diff could not be displayed because it is too large. You can
view the blob
instead.
notebooks/train_sfno.ipynb
View file @
d81fbd34
...
...
@@ -176,7 +176,7 @@
"# activation_function = nn.ReLU,\n",
"# bias = False):\n",
"# super().__init__()\n",
"
\n",
"\n",
"# current_dim = input_dim\n",
"# layers = []\n",
"# for l in range(num_layers-1):\n",
...
...
@@ -221,7 +221,7 @@
" loss = solver.integrate_grid((prd - tar)**2, dimensionless=True).sum(dim=-1)\n",
" if relative:\n",
" loss = loss / solver.integrate_grid(tar**2, dimensionless=True).sum(dim=-1)\n",
"
\n",
"\n",
" if not squared:\n",
" loss = torch.sqrt(loss)\n",
" loss = loss.mean()\n",
...
...
@@ -515,7 +515,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3
(ipykernel)
",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
...
...
@@ -531,12 +531,7 @@
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
...
...
torch_harmonics/convolution.py
View file @
d81fbd34
...
...
@@ -57,7 +57,7 @@ except ImportError as err:
def
_normalize_convolution_tensor_s2
(
psi_idx
,
psi_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
False
,
basis_norm_mode
=
"
none
"
,
merge_quadrature
=
False
,
eps
=
1e-9
psi_idx
,
psi_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
False
,
basis_norm_mode
=
"
mean
"
,
merge_quadrature
=
False
,
eps
=
1e-9
):
"""
Discretely normalizes the convolution tensor and pre-applies quadrature weights. Supports the following three normalization modes:
...
...
@@ -135,7 +135,7 @@ def _precompute_convolution_tensor_s2(
grid_out
=
"equiangular"
,
theta_cutoff
=
0.01
*
math
.
pi
,
transpose_normalization
=
False
,
basis_norm_mode
=
"
none
"
,
basis_norm_mode
=
"
mean
"
,
merge_quadrature
=
False
,
):
"""
...
...
@@ -297,7 +297,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape
:
Tuple
[
int
],
kernel_shape
:
Union
[
int
,
List
[
int
]],
basis_type
:
Optional
[
str
]
=
"piecewise linear"
,
basis_norm_mode
:
Optional
[
str
]
=
"
none
"
,
basis_norm_mode
:
Optional
[
str
]
=
"
mean
"
,
groups
:
Optional
[
int
]
=
1
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
...
...
@@ -403,7 +403,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape
:
Tuple
[
int
],
kernel_shape
:
Union
[
int
,
List
[
int
]],
basis_type
:
Optional
[
str
]
=
"piecewise linear"
,
basis_norm_mode
:
Optional
[
str
]
=
"
none
"
,
basis_norm_mode
:
Optional
[
str
]
=
"
mean
"
,
groups
:
Optional
[
int
]
=
1
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
...
...
torch_harmonics/distributed/distributed_convolution.py
View file @
d81fbd34
...
...
@@ -76,7 +76,7 @@ def _precompute_distributed_convolution_tensor_s2(
grid_out
=
"equiangular"
,
theta_cutoff
=
0.01
*
math
.
pi
,
transpose_normalization
=
False
,
basis_norm_mode
=
"
none
"
,
basis_norm_mode
=
"
mean
"
,
merge_quadrature
=
False
,
):
"""
...
...
@@ -208,7 +208,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape
:
Tuple
[
int
],
kernel_shape
:
Union
[
int
,
List
[
int
]],
basis_type
:
Optional
[
str
]
=
"piecewise linear"
,
basis_norm_mode
:
Optional
[
str
]
=
"
none
"
,
basis_norm_mode
:
Optional
[
str
]
=
"
mean
"
,
groups
:
Optional
[
int
]
=
1
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
...
...
@@ -348,7 +348,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape
:
Tuple
[
int
],
kernel_shape
:
Union
[
int
,
List
[
int
]],
basis_type
:
Optional
[
str
]
=
"piecewise linear"
,
basis_norm_mode
:
Optional
[
str
]
=
"
none
"
,
basis_norm_mode
:
Optional
[
str
]
=
"
mean
"
,
groups
:
Optional
[
int
]
=
1
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
...
...
torch_harmonics/examples/models/lsno.py
View file @
d81fbd34
...
...
@@ -35,6 +35,7 @@ import torch.amp as amp
from
torch_harmonics
import
RealSHT
,
InverseRealSHT
from
torch_harmonics
import
DiscreteContinuousConvS2
,
DiscreteContinuousConvTransposeS2
from
torch_harmonics
import
ResampleS2
from
._layers
import
*
...
...
@@ -44,7 +45,7 @@ from functools import partial
class
DiscreteContinuousEncoder
(
nn
.
Module
):
def
__init__
(
self
,
in
p
_shape
=
(
721
,
1440
),
in_shape
=
(
721
,
1440
),
out_shape
=
(
480
,
960
),
grid_in
=
"equiangular"
,
grid_out
=
"equiangular"
,
...
...
@@ -61,7 +62,7 @@ class DiscreteContinuousEncoder(nn.Module):
self
.
conv
=
DiscreteContinuousConvS2
(
inp_chans
,
out_chans
,
in_shape
=
in
p
_shape
,
in_shape
=
in_shape
,
out_shape
=
out_shape
,
kernel_shape
=
kernel_shape
,
basis_type
=
basis_type
,
...
...
@@ -69,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out
=
grid_out
,
groups
=
groups
,
bias
=
bias
,
theta_cutoff
=
math
.
sqrt
(
2
)
*
torch
.
pi
/
float
(
out_shape
[
0
]
-
1
),
theta_cutoff
=
math
.
sqrt
(
2
.0
)
*
torch
.
pi
/
float
(
out_shape
[
0
]
-
1
),
)
def
forward
(
self
,
x
):
...
...
@@ -86,7 +87,7 @@ class DiscreteContinuousEncoder(nn.Module):
class
DiscreteContinuousDecoder
(
nn
.
Module
):
def
__init__
(
self
,
in
p
_shape
=
(
480
,
960
),
in_shape
=
(
480
,
960
),
out_shape
=
(
721
,
1440
),
grid_in
=
"equiangular"
,
grid_out
=
"equiangular"
,
...
...
@@ -99,12 +100,13 @@ class DiscreteContinuousDecoder(nn.Module):
):
super
().
__init__
()
# set up
self
.
sht
=
RealSHT
(
*
in
p
_shape
,
grid
=
grid_in
).
float
()
#
#
set up
self
.
sht
=
RealSHT
(
*
in_shape
,
grid
=
grid_in
).
float
()
self
.
isht
=
InverseRealSHT
(
*
out_shape
,
lmax
=
self
.
sht
.
lmax
,
mmax
=
self
.
sht
.
mmax
,
grid
=
grid_out
).
float
()
# self.upscale = ResampleS2(*in_shape, *out_shape, grid_in=grid_in, grid_out=grid_out)
# set up DISCO convolution
self
.
conv
t
=
DiscreteContinuousConv
Transpose
S2
(
self
.
conv
=
DiscreteContinuousConvS2
(
inp_chans
,
out_chans
,
in_shape
=
out_shape
,
...
...
@@ -115,21 +117,22 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out
=
grid_out
,
groups
=
groups
,
bias
=
False
,
theta_cutoff
=
math
.
sqrt
(
2
)
*
torch
.
pi
/
float
(
in
p
_shape
[
0
]
-
1
),
theta_cutoff
=
math
.
sqrt
(
2
.0
)
*
torch
.
pi
/
float
(
in_shape
[
0
]
-
1
),
)
# self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)
def
_
upscale_sht
(
self
,
x
:
torch
.
Tensor
):
def
upscale_sht
(
self
,
x
:
torch
.
Tensor
):
return
self
.
isht
(
self
.
sht
(
x
))
def
forward
(
self
,
x
):
dtype
=
x
.
dtype
# x = self.upscale(x)
with
amp
.
autocast
(
device_type
=
"cuda"
,
enabled
=
False
):
x
=
x
.
float
()
x
=
self
.
_
upscale_sht
(
x
)
x
=
self
.
conv
t
(
x
)
x
=
self
.
upscale_sht
(
x
)
x
=
self
.
conv
(
x
)
x
=
x
.
to
(
dtype
=
dtype
)
return
x
...
...
@@ -182,7 +185,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
grid_in
=
forward_transform
.
grid
,
grid_out
=
inverse_transform
.
grid
,
bias
=
False
,
theta_cutoff
=
4
*
math
.
sqrt
(
2
)
*
torch
.
pi
/
float
(
inverse_transform
.
nlat
-
1
),
theta_cutoff
=
4
*
math
.
sqrt
(
2
.0
)
*
torch
.
pi
/
float
(
inverse_transform
.
nlat
-
1
),
)
elif
conv_type
==
"global"
:
self
.
global_conv
=
SpectralConvS2
(
forward_transform
,
inverse_transform
,
input_dim
,
output_dim
,
gain
=
gain_factor
,
operator_type
=
operator_type
,
bias
=
False
)
...
...
@@ -272,8 +275,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Parameters
-----------
spectral_transform : str, optional
Type of spectral transformation to use, by default "sht"
operator_type : str, optional
Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
img_shape : tuple, optional
...
...
@@ -339,7 +340,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
def
__init__
(
self
,
spectral_transform
=
"sht"
,
operator_type
=
"driscoll-healy"
,
img_size
=
(
128
,
256
),
grid
=
"equiangular"
,
...
...
@@ -365,7 +365,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
):
super
().
__init__
()
self
.
spectral_transform
=
spectral_transform
self
.
operator_type
=
operator_type
self
.
img_size
=
img_size
self
.
grid
=
grid
...
...
@@ -440,8 +439,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
theta_cutoff
=
math
.
sqrt
(
2
)
*
torch
.
pi
/
float
(
self
.
h
-
1
),
)
# prepare the spectral transform
if
self
.
spectral_transform
==
"sht"
:
# prepare the SHT
modes_lat
=
int
(
self
.
h
*
self
.
hard_thresholding_fraction
)
modes_lon
=
int
(
self
.
w
//
2
*
self
.
hard_thresholding_fraction
)
modes_lat
=
modes_lon
=
min
(
modes_lat
,
modes_lon
)
...
...
@@ -449,9 +447,6 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
self
.
trans
=
RealSHT
(
self
.
h
,
self
.
w
,
lmax
=
modes_lat
,
mmax
=
modes_lon
,
grid
=
grid_internal
).
float
()
self
.
itrans
=
InverseRealSHT
(
self
.
h
,
self
.
w
,
lmax
=
modes_lat
,
mmax
=
modes_lon
,
grid
=
grid_internal
).
float
()
else
:
raise
(
ValueError
(
"Unknown spectral transform"
))
self
.
blocks
=
nn
.
ModuleList
([])
for
i
in
range
(
self
.
num_layers
):
first_layer
=
i
==
0
...
...
@@ -490,7 +485,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
# decoder
self
.
decoder
=
DiscreteContinuousDecoder
(
in
p
_shape
=
(
self
.
h
,
self
.
w
),
in_shape
=
(
self
.
h
,
self
.
w
),
out_shape
=
self
.
img_size
,
grid_in
=
grid_internal
,
grid_out
=
grid
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment