Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
3350099a
Commit
3350099a
authored
Jan 13, 2025
by
Boris Bonev
Committed by
Boris Bonev
Jan 14, 2025
Browse files
adding option in LSNO to select between upsampling methods
parent
b6b2bce3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
14 deletions
+20
-14
examples/train_sfno.py
examples/train_sfno.py
+4
-2
torch_harmonics/examples/models/lsno.py
torch_harmonics/examples/models/lsno.py
+16
-12
No files found.
examples/train_sfno.py
View file @
3350099a
...
...
@@ -445,7 +445,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer
=
"none"
,
kernel_shape
=
[
4
,
4
],
encoder_kernel_shape
=
[
4
,
4
],
filter_basis_type
=
"morlet"
filter_basis_type
=
"morlet"
,
upsample_sht
=
True
,
)
models
[
f
"lsno_sc2_layers4_e32_zernike"
]
=
partial
(
...
...
@@ -463,7 +464,8 @@ def main(train=True, load_checkpoint=False, enable_amp=False, log_grads=0):
normalization_layer
=
"none"
,
kernel_shape
=
[
4
],
encoder_kernel_shape
=
[
4
],
filter_basis_type
=
"zernike"
filter_basis_type
=
"zernike"
,
upsample_sht
=
True
,
)
# iterate over models and train each model
...
...
torch_harmonics/examples/models/lsno.py
View file @
3350099a
...
...
@@ -70,7 +70,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out
=
grid_out
,
groups
=
groups
,
bias
=
bias
,
theta_cutoff
=
1
.0
*
torch
.
pi
/
float
(
out_shape
[
0
]
-
1
),
theta_cutoff
=
4
.0
*
torch
.
pi
/
float
(
out_shape
[
0
]
-
1
),
)
def
forward
(
self
,
x
):
...
...
@@ -97,13 +97,17 @@ class DiscreteContinuousDecoder(nn.Module):
basis_type
=
"piecewise linear"
,
groups
=
1
,
bias
=
False
,
upsample_sht
=
False
):
super
().
__init__
()
# # set up
# set up upsampling
if
upsample_sht
:
self
.
sht
=
RealSHT
(
*
in_shape
,
grid
=
grid_in
).
float
()
self
.
isht
=
InverseRealSHT
(
*
out_shape
,
lmax
=
self
.
sht
.
lmax
,
mmax
=
self
.
sht
.
mmax
,
grid
=
grid_out
).
float
()
self
.
upscale
=
ResampleS2
(
*
in_shape
,
*
out_shape
,
grid_in
=
grid_in
,
grid_out
=
grid_out
)
self
.
upsample
=
nn
.
Sequential
(
self
.
sht
,
self
.
isht
)
else
:
self
.
upsample
=
ResampleS2
(
*
in_shape
,
*
out_shape
,
grid_in
=
grid_in
,
grid_out
=
grid_out
)
# set up DISCO convolution
self
.
conv
=
DiscreteContinuousConvS2
(
...
...
@@ -117,19 +121,15 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out
=
grid_out
,
groups
=
groups
,
bias
=
False
,
theta_cutoff
=
1
.0
*
torch
.
pi
/
float
(
in_shape
[
0
]
-
1
),
theta_cutoff
=
4
.0
*
torch
.
pi
/
float
(
in_shape
[
0
]
-
1
),
)
def
upscale_sht
(
self
,
x
:
torch
.
Tensor
):
return
self
.
isht
(
self
.
sht
(
x
))
def
forward
(
self
,
x
):
dtype
=
x
.
dtype
x
=
self
.
upscale
(
x
)
with
amp
.
autocast
(
device_type
=
"cuda"
,
enabled
=
False
):
x
=
x
.
float
()
#
x = self.ups
cale_sht
(x)
x
=
self
.
ups
ample
(
x
)
x
=
self
.
conv
(
x
)
x
=
x
.
to
(
dtype
=
dtype
)
...
...
@@ -182,7 +182,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
grid_in
=
forward_transform
.
grid
,
grid_out
=
inverse_transform
.
grid
,
bias
=
False
,
theta_cutoff
=
1
.0
*
(
disco_kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
inverse_transform
.
nlat
-
1
),
theta_cutoff
=
4
.0
*
(
disco_kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
inverse_transform
.
nlat
-
1
),
)
elif
conv_type
==
"global"
:
self
.
global_conv
=
SpectralConvS2
(
forward_transform
,
inverse_transform
,
input_dim
,
output_dim
,
gain
=
gain_factor
,
operator_type
=
operator_type
,
bias
=
False
)
...
...
@@ -309,6 +309,8 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
Whether to add a single large skip connection, by default True
pos_embed : bool, optional
Whether to use positional embedding, by default True
upsample_sht : bool, optional
Use SHT upsampling if true, else linear interpolation
Example
-----------
...
...
@@ -359,6 +361,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
use_complex_kernels
=
True
,
big_skip
=
False
,
pos_embed
=
False
,
upsample_sht
=
False
,
):
super
().
__init__
()
...
...
@@ -491,6 +494,7 @@ class LocalSphericalNeuralOperatorNet(nn.Module):
basis_type
=
filter_basis_type
,
groups
=
1
,
bias
=
False
,
upsample_sht
=
upsample_sht
)
# # residual prediction
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment