OpenDAS / torch-harmonics / Commits

Commit 7fb0c483 (unverified), authored Oct 04, 2023 by Boris Bonev, committed by GitHub Oct 04, 2023

Bbonev/sfno update (#10)

* reworked SFNO example
* updated changelog

parent cec07d7a
Showing 7 changed files with 247 additions and 159 deletions.
Changelog.md                                             +1   -0
examples/train_sfno.py                                   +1   -1
notebooks/train_sfno.ipynb                               +41  -41
torch_harmonics/examples/sfno/models/contractions.py     +10  -30
torch_harmonics/examples/sfno/models/factorizations.py   +3   -3
torch_harmonics/examples/sfno/models/layers.py           +95  -8
torch_harmonics/examples/sfno/models/sfno.py             +96  -76
Changelog.md
@@ -5,6 +5,7 @@
 ### v0.6.3

 * Adding gradient check in unit tests
+* Updated SFNO example

 ### v0.6.2
examples/train_sfno.py
@@ -334,7 +334,7 @@ def main(train=True, load_checkpoint=False, enable_amp=False):
     # SFNO models
     models['sfno_sc3_layer4_edim256_linear'] = partial(SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon),
-                                                       num_layers=4, scale_factor=3, embed_dim=256, operator_type='vector')
+                                                       num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy')
     models['sfno_sc3_layer4_edim256_real'] = partial(SFNO, spectral_transform='sht', filter_type='non-linear', img_size=(nlat, nlon),
                                                      num_layers=4, scale_factor=3, embed_dim=256, complex_activation='real', operator_type='diagonal')

     # FNO models
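For reference, the training script registers model configurations as functools.partial objects, so each named entry can be instantiated later without arguments. A minimal sketch of the same pattern; the import path and grid sizes are assumptions, not taken from this diff:

from functools import partial

import torch
from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO  # assumed import path

nlat, nlon = 128, 256  # placeholder equiangular grid size

models = {}
models['sfno_sc3_layer4_edim256_linear'] = partial(
    SFNO, spectral_transform='sht', filter_type='linear', img_size=(nlat, nlon),
    num_layers=4, scale_factor=3, embed_dim=256, operator_type='driscoll-healy')

model = models['sfno_sc3_layer4_edim256_linear']()  # instantiate the registered configuration
out = model(torch.randn(1, 3, nlat, nlon))          # default in_chans=3 per the constructor below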
notebooks/train_sfno.ipynb
This diff is collapsed.
torch_harmonics/examples/sfno/models/contractions.py
@@ -36,32 +36,27 @@ Contains complex contractions wrapped into jit for harmonic layers
 """

-@torch.jit.script
-def compl_contract2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-    tmp = torch.einsum("bixys,kixr->srbkx", a, b)
-    res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
-    return res
-
 @torch.jit.script
-def compl_contract2d_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+def contract_diagonal(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     ac = torch.view_as_complex(a)
     bc = torch.view_as_complex(b)
-    res = torch.einsum("bixy,kix->bkx", ac, bc)
+    res = torch.einsum("bixy,kixy->bkxy", ac, bc)
     return torch.view_as_real(res)

-@torch.jit.script
-def compl_contract_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-    tmp = torch.einsum("bins,kinr->srbkn", a, b)
-    res = torch.stack([tmp[0,0,...] - tmp[1,1,...], tmp[1,0,...] + tmp[0,1,...]], dim=-1)
-    return res
+@torch.jit.script
+def contract_dhconv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+    ac = torch.view_as_complex(a)
+    bc = torch.view_as_complex(b)
+    res = torch.einsum("bixy,kix->bkxy", ac, bc)
+    return torch.view_as_real(res)

 @torch.jit.script
-def compl_contract_fwd_c(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+def contract_blockdiag(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     ac = torch.view_as_complex(a)
     bc = torch.view_as_complex(b)
-    res = torch.einsum("bin,kin->bkn", ac, bc)
+    res = torch.einsum("bixy,kixyz->bkxz", ac, bc)
     return torch.view_as_real(res)

-# Helper routines for spherical MLPs
+# Helper routines for the non-linear FNOs (Attention-like)

 @torch.jit.script
 def compl_mul1d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     tmp = torch.einsum("bixs,ior->srbox", a, b)

@@ -124,18 +119,3 @@ def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
 def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
     return compl_mul2d_fwd_c(a, b) + c

-# for all the experimental layers
-# @torch.jit.script
-# def compl_exp_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-#     ac = torch.view_as_complex(a)
-#     bc = torch.view_as_complex(b)
-#     resc = torch.einsum("bixy,xio->boxy", ac, bc)
-#     res = torch.view_as_real(resc)
-#     return res
-
-# @torch.jit.script
-# def compl_exp_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
-#     tmpcc = torch.view_as_complex(compl_exp_mul2d_fwd(a, b))
-#     cc = torch.view_as_complex(c)
-#     return torch.view_as_real(tmpcc + cc)
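The three contraction helpers that replace the old compl_contract_* routines encode how the learned weight couples spherical harmonic modes: 'diagonal' learns one weight per (l, m) mode, 'driscoll-healy' one weight per degree l shared across all orders m, and 'block-diagonal' a full coupling across orders. A quick shape check of the first two, assuming the real-valued trailing (re, im) layout used throughout this file (the import path simply mirrors the file path above):

import torch
from torch_harmonics.examples.sfno.models.contractions import contract_diagonal, contract_dhconv

b, c_in, c_out, lmax, mmax = 2, 4, 8, 16, 9
a = torch.randn(b, c_in, lmax, mmax, 2)           # spectral activations: (batch, in, l, m, re/im)
w_diag = torch.randn(c_out, c_in, lmax, mmax, 2)  # 'diagonal': one weight per (l, m)
w_dh = torch.randn(c_out, c_in, lmax, 2)          # 'driscoll-healy': one weight per degree l

print(contract_diagonal(a, w_diag).shape)  # torch.Size([2, 8, 16, 9, 2])
print(contract_dhconv(a, w_dh).shape)      # torch.Size([2, 8, 16, 9, 2])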
torch_harmonics/examples/sfno/models/factorizations.py
@@ -59,7 +59,7 @@ def _contract_dense(x, weight, separable=False, operator_type='diagonal'):
     elif operator_type == 'block-diagonal':
         weight_syms.insert(-1, einsum_symbols[order+1])
         out_syms[-1] = weight_syms[-2]
-    elif operator_type == 'vector':
+    elif operator_type == 'driscoll-healy':
         weight_syms.pop()
     else:
         raise ValueError(f"Unkonw operator type {operator_type}")

@@ -92,7 +92,7 @@ def _contract_cp(x, cp_weight, separable=False, operator_type='diagonal'):
     elif operator_type == 'block-diagonal':
         out_syms[-1] = einsum_symbols[order+2]
         factor_syms += [out_syms[-1] + rank_sym]
-    elif operator_type == 'vector':
+    elif operator_type == 'driscoll-healy':
         factor_syms.pop()
     else:
         raise ValueError(f"Unkonw operator type {operator_type}")

@@ -148,7 +148,7 @@ def _contract_tt(x, tt_weight, separable=False, operator_type='diagonal'):
     elif operator_type == 'block-diagonal':
         weight_syms.insert(-1, einsum_symbols[order+1])
         out_syms[-1] = weight_syms[-2]
-    elif operator_type == 'vector':
+    elif operator_type == 'driscoll-healy':
         weight_syms.pop()
     else:
         raise ValueError(f"Unkonw operator type {operator_type}")
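These branches only rename the 'vector' operator type to 'driscoll-healy'; the einsum-symbol manipulation is unchanged. A self-contained sketch of the idea behind weight_syms.pop(), dropping the order axis from a per-mode weight spec (names and structure here are illustrative, not the library's internals):

import torch

def dense_contract(x, weight, operator_type='driscoll-healy'):
    # x: (batch, in, l, m); start from a per-mode ("diagonal") weight spec
    einsum_symbols = 'abcdefghijklmnopqrstuvwxyz'
    order = x.ndim                                 # 4
    x_syms = list(einsum_symbols[:order])          # ['a', 'b', 'c', 'd']
    weight_syms = list(x_syms[1:])                 # (in, l, m)
    out_sym = einsum_symbols[order]                # output-channel symbol
    weight_syms.insert(0, out_sym)                 # (out, in, l, m)
    if operator_type == 'driscoll-healy':
        weight_syms.pop()                          # drop 'm': weight shared across orders
    out_syms = [x_syms[0], out_sym] + x_syms[2:]   # (batch, out, l, m)
    eq = f"{''.join(x_syms)},{''.join(weight_syms)}->{''.join(out_syms)}"
    return torch.einsum(eq, x, weight)

x = torch.randn(2, 4, 16, 9, dtype=torch.cfloat)
w = torch.randn(8, 4, 16, dtype=torch.cfloat)  # no m axis
print(dense_contract(x, w).shape)              # torch.Size([2, 8, 16, 9])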
torch_harmonics/examples/sfno/models/layers.py
@@ -40,8 +40,6 @@ from torch_harmonics import *
 from .contractions import *
 from .activations import *
+from .factorizations import get_contract_fun

-# # import FactorizedTensor from tensorly for tensorized operations
-# import tensorly as tl
-# from tensorly.plugins import use_opt_einsum
@@ -207,7 +205,7 @@ class InverseRealFFT2(nn.Module):
     def forward(self, x):
         return torch.fft.irfft2(x, dim=(-2, -1), s=(self.nlat, self.nlon), norm="ortho")

 class SpectralConvS2(nn.Module):
     """
     Spectral Convolution according to Driscoll & Healy. Designed for convolutions on the two-sphere S2
@@ -221,7 +219,95 @@ class SpectralConvS2(nn.Module):
                  in_channels,
                  out_channels,
                  scale = 'auto',
-                 operator_type = 'diagonal',
+                 operator_type = 'driscoll-healy',
+                 lr_scale_exponent = 0,
                  bias = False):

         super(SpectralConvS2, self).__init__()

+        if scale == 'auto':
+            scale = (2 / in_channels)**0.5
+
         self.forward_transform = forward_transform
         self.inverse_transform = inverse_transform

+        self.modes_lat = self.inverse_transform.lmax
+        self.modes_lon = self.inverse_transform.mmax
+
+        self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) \
+            or (self.forward_transform.nlon != self.inverse_transform.nlon)
+
+        # remember factorization details
+        self.operator_type = operator_type
+
+        assert self.inverse_transform.lmax == self.modes_lat
+        assert self.inverse_transform.mmax == self.modes_lon
+
+        weight_shape = [in_channels, out_channels]
+
+        if self.operator_type == 'diagonal':
+            weight_shape += [self.modes_lat, self.modes_lon]
+            from .contractions import contract_diagonal as _contract
+        elif self.operator_type == 'block-diagonal':
+            weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
+            from .contractions import contract_blockdiag as _contract
+        elif self.operator_type == 'driscoll-healy':
+            weight_shape += [self.modes_lat]
+            from .contractions import contract_dhconv as _contract
+        else:
+            raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
+
+        # form weight tensors
+        self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2))
+
+        # rescale the learning rate for better training of spectral parameters
+        lr_scale = (torch.arange(self.modes_lat) + 1).reshape(-1, 1) ** (lr_scale_exponent)
+        self.register_buffer("lr_scale", lr_scale)
+        # self.weight.register_hook(lambda grad: self.lr_scale*grad)
+
+        # get the right contraction function
+        self._contract = _contract
+
+        if bias:
+            self.bias = nn.Parameter(scale * torch.randn(1, out_channels, 1, 1))
+
+    def forward(self, x):
+
+        dtype = x.dtype
+        x = x.float()
+        residual = x
+
+        with amp.autocast(enabled=False):
+            x = self.forward_transform(x)
+            if self.scale_residual:
+                residual = self.inverse_transform(x)
+
+        x = torch.view_as_real(x)
+        x = self._contract(x, self.weight)
+        x = torch.view_as_complex(x)
+
+        with amp.autocast(enabled=False):
+            x = self.inverse_transform(x)
+
+        if hasattr(self, 'bias'):
+            x = x + self.bias
+
+        x = x.type(dtype)
+
+        return x, residual
+
+
+class FactorizedSpectralConvS2(nn.Module):
+    """
+    Factorized version of SpectralConvS2. Uses tensorly-torch to keep the weights factorized
+    """
+
+    def __init__(self,
+                 forward_transform,
+                 inverse_transform,
+                 in_channels,
+                 out_channels,
+                 scale = 'auto',
+                 operator_type = 'driscoll-healy',
+                 rank = 0.2,
+                 factorization = None,
+                 separable = False,
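A minimal usage sketch of the new unfactorized SpectralConvS2, assuming the RealSHT/InverseRealSHT transforms from torch_harmonics and that the class is importable from the layers module above (grid sizes and import path are placeholders):

import torch
from torch_harmonics import RealSHT, InverseRealSHT
from torch_harmonics.examples.sfno.models.layers import SpectralConvS2  # assumed import path

nlat, nlon = 32, 64
fwd = RealSHT(nlat, nlon, grid="equiangular").float()
inv = InverseRealSHT(nlat, nlon, grid="equiangular").float()

# 'driscoll-healy' stores one weight per degree l, so the weight tensor is (out, in, lmax, 2)
conv = SpectralConvS2(fwd, inv, in_channels=8, out_channels=8,
                      operator_type='driscoll-healy', bias=True)

x = torch.randn(2, 8, nlat, nlon)
y, residual = conv(x)  # forward returns the filtered signal and a (possibly resampled) residual
print(y.shape)         # torch.Size([2, 8, 32, 64])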
@@ -231,7 +317,7 @@ class SpectralConvS2(nn.Module):
         super(SpectralConvS2, self).__init__()

         if scale == 'auto':
-            scale = (1 / (in_channels * out_channels))
+            scale = (2 / in_channels)**0.5

         self.forward_transform = forward_transform
         self.inverse_transform = inverse_transform
@@ -266,7 +352,7 @@ class SpectralConvS2(nn.Module):
             weight_shape += [self.modes_lat, self.modes_lon]
         elif self.operator_type == 'block-diagonal':
             weight_shape += [self.modes_lat, self.modes_lon, self.modes_lon]
-        elif self.operator_type == 'vector':
+        elif self.operator_type == 'driscoll-healy':
             weight_shape += [self.modes_lat]
         else:
             raise NotImplementedError(f"Unkonw operator type f{self.operator_type}")
@@ -278,6 +364,8 @@ class SpectralConvS2(nn.Module):
         # initialization of weights
         self.weight.normal_(0, scale)

+        # get the right contraction function
+        from .factorizations import get_contract_fun
         self._contract = get_contract_fun(self.weight, implementation=implementation, separable=separable)

         if bias:
@@ -289,7 +377,6 @@ class SpectralConvS2(nn.Module):
         dtype = x.dtype
         x = x.float()
         residual = x
-        B, C, H, W = x.shape

         with amp.autocast(enabled=False):
             x = self.forward_transform(x)
@@ -467,7 +554,7 @@ class SpectralAttentionS2(nn.Module):
             for l in range(0, self.spectral_layers):
                 self.activations.append(ComplexReLU(mode=complex_activation, bias_shape=(hidden_size, 1, 1), scale=self.scale))

-        elif operator_type == 'vector':
+        elif operator_type == 'driscoll-healy':
             self.mul_add_handle = compl_exp_muladd2d_fwd
             self.mul_handle = compl_exp_mul2d_fwd
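The lr_scale buffer added to SpectralConvS2 is the hook point for the lr_scale_exponent option threaded through this commit: gradients of spectral weights at degree l would be rescaled by (l+1)**exponent (the register_hook line is still commented out above). A tiny worked example of the buffer itself:

import torch

modes_lat, lr_scale_exponent = 4, -1.0
lr_scale = (torch.arange(modes_lat, dtype=torch.float32) + 1).reshape(-1, 1) ** lr_scale_exponent
print(lr_scale.flatten())  # tensor([1.0000, 0.5000, 0.3333, 0.2500])
# hooked onto the weight gradient, this damps updates to high-degree (small-scale) modes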
torch_harmonics/examples/sfno/models/sfno.py
@@ -48,20 +48,21 @@ class SpectralFilterLayer(nn.Module):
                  forward_transform,
                  inverse_transform,
                  embed_dim,
-                 filter_type = 'non-linear',
-                 operator_type = 'diagonal',
+                 filter_type = "non-linear",
+                 operator_type = "diagonal",
                  sparsity_threshold = 0.0,
                  use_complex_kernels = True,
                  hidden_size_factor = 2,
+                 lr_scale_exponent = 0,
                  factorization = None,
                  separable = False,
                  rank = 1e-2,
-                 complex_activation = 'real',
+                 complex_activation = "real",
                  spectral_layers = 1,
                  drop_rate = 0):
         super(SpectralFilterLayer, self).__init__()

-        if filter_type == 'non-linear' and isinstance(forward_transform, RealSHT):
+        if filter_type == "non-linear" and isinstance(forward_transform, RealSHT):
             self.filter = SpectralAttentionS2(forward_transform,
                                               inverse_transform,
                                               embed_dim,
@@ -73,7 +74,7 @@ class SpectralFilterLayer(nn.Module):
                                               drop_rate = drop_rate,
                                               bias = False)

-        elif filter_type == 'non-linear' and isinstance(forward_transform, RealFFT2):
+        elif filter_type == "non-linear" and isinstance(forward_transform, RealFFT2):
             self.filter = SpectralAttention2d(forward_transform,
                                               inverse_transform,
                                               embed_dim,
@@ -85,16 +86,25 @@ class SpectralFilterLayer(nn.Module):
                                               drop_rate = drop_rate,
                                               bias = False)

-        elif filter_type == 'linear':
+        elif filter_type == "linear" and factorization is None:
             self.filter = SpectralConvS2(forward_transform,
                                          inverse_transform,
                                          embed_dim,
                                          embed_dim,
                                          operator_type = operator_type,
-                                         rank = rank,
-                                         factorization = factorization,
-                                         separable = separable,
+                                         lr_scale_exponent = lr_scale_exponent,
                                          bias = True)

+        elif filter_type == "linear" and factorization is not None:
+            self.filter = FactorizedSpectralConvS2(forward_transform,
+                                                   inverse_transform,
+                                                   embed_dim,
+                                                   embed_dim,
+                                                   operator_type = operator_type,
+                                                   rank = rank,
+                                                   factorization = factorization,
+                                                   separable = separable,
+                                                   bias = True)
+
         else:
             raise(NotImplementedError)
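SpectralFilterLayer now dispatches on both filter_type and the presence of a factorization, separating the dense and factorized spherical convolutions into distinct classes. A runnable sketch of just the selection logic (the returned strings stand in for the classes named in this diff):

def select_filter(filter_type, factorization):
    if filter_type == "non-linear":
        return "SpectralAttentionS2 or SpectralAttention2d (chosen by transform type)"
    elif filter_type == "linear" and factorization is None:
        return "SpectralConvS2"            # dense weights, supports lr_scale_exponent
    elif filter_type == "linear" and factorization is not None:
        return "FactorizedSpectralConvS2"  # tensorly-torch factorized weights
    else:
        raise NotImplementedError

print(select_filter("linear", None))  # SpectralConvS2
print(select_filter("linear", "cp"))  # FactorizedSpectralConvS2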
@@ -111,29 +121,27 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
                  forward_transform,
                  inverse_transform,
                  embed_dim,
-                 filter_type = 'non-linear',
-                 operator_type = 'diagonal',
+                 filter_type = "non-linear",
+                 operator_type = "driscoll-healy",
                  mlp_ratio = 2.,
                  drop_rate = 0.,
                  drop_path = 0.,
                  act_layer = nn.GELU,
-                 norm_layer = (nn.LayerNorm, nn.LayerNorm),
+                 norm_layer = nn.Identity,
                  sparsity_threshold = 0.0,
                  use_complex_kernels = True,
+                 lr_scale_exponent = 0,
                  factorization = None,
                  separable = False,
                  rank = 128,
-                 inner_skip = 'linear',
-                 outer_skip = None, # None, nn.linear or nn.Identity
+                 inner_skip = "linear",
+                 outer_skip = None,
                  concat_skip = False,
                  use_mlp = True,
-                 complex_activation = 'real',
+                 complex_activation = "real",
                  spectral_layers = 3):
         super(SphericalFourierNeuralOperatorBlock, self).__init__()

-        # norm layer
-        self.norm0 = norm_layer[0]() #((h,w))
-
         # convolution layer
         self.filter = SpectralFilterLayer(forward_transform,
                                           inverse_transform,
@@ -143,6 +151,7 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
                                           sparsity_threshold = sparsity_threshold,
                                           use_complex_kernels = use_complex_kernels,
                                           hidden_size_factor = mlp_ratio,
+                                          lr_scale_exponent = lr_scale_exponent,
                                           factorization = factorization,
                                           separable = separable,
                                           rank = rank,
@@ -150,24 +159,28 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
                                           spectral_layers = spectral_layers,
                                           drop_rate = drop_rate)

-        if inner_skip == 'linear':
+        if inner_skip == "linear":
             self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
-        elif inner_skip == 'identity':
+        elif inner_skip == "identity":
             self.inner_skip = nn.Identity()
+        elif inner_skip == "none":
+            pass
+        else:
+            raise ValueError(f"Unknown skip connection type {inner_skip}")

         self.concat_skip = concat_skip

         if concat_skip and inner_skip is not None:
             self.inner_skip_conv = nn.Conv2d(2 * embed_dim, embed_dim, 1, bias=False)

-        if filter_type == 'linear' or filter_type == 'local':
+        if filter_type == "linear":
             self.act_layer = act_layer()

+        # first normalisation layer
+        self.norm0 = norm_layer()
+
         # dropout
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        # norm layer
-        self.norm1 = norm_layer[1]() #((h,w))
-
         if use_mlp == True:
             mlp_hidden_dim = int(embed_dim * mlp_ratio)
@@ -177,44 +190,51 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
                        drop_rate = drop_rate,
                        checkpointing = False)

-        if outer_skip == 'linear':
+        if outer_skip == "linear":
             self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1)
-        elif outer_skip == 'identity':
+        elif outer_skip == "identity":
             self.outer_skip = nn.Identity()
+        elif outer_skip == "none":
+            pass
+        else:
+            raise ValueError(f"Unknown skip connection type {outer_skip}")

         if concat_skip and outer_skip is not None:
             self.outer_skip_conv = nn.Conv2d(2 * embed_dim, embed_dim, 1, bias=False)

+        # second normalisation layer
+        self.norm1 = norm_layer()
+
     def forward(self, x):

-        x = self.norm0(x)
-
         x, residual = self.filter(x)

-        if hasattr(self, 'inner_skip'):
+        if hasattr(self, "inner_skip"):
             if self.concat_skip:
                 x = torch.cat((x, self.inner_skip(residual)), dim=1)
                 x = self.inner_skip_conv(x)
             else:
                 x = x + self.inner_skip(residual)

-        if hasattr(self, 'act_layer'):
+        if hasattr(self, "act_layer"):
             x = self.act_layer(x)

-        x = self.norm1(x)
+        x = self.norm0(x)

-        if hasattr(self, 'mlp'):
+        if hasattr(self, "mlp"):
             x = self.mlp(x)

         x = self.drop_path(x)

-        if hasattr(self, 'outer_skip'):
+        if hasattr(self, "outer_skip"):
             if self.concat_skip:
                 x = torch.cat((x, self.outer_skip(residual)), dim=1)
                 x = self.outer_skip_conv(x)
             else:
                 x = x + self.outer_skip(residual)

+        x = self.norm1(x)
+
         return x
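The normalisation placement in the block changes with this commit: the old pre-filter norm0 and post-activation norm1 become a norm after each half of the block. A runnable schematic of the revised forward pass, with identity stand-ins for the real submodules:

import torch
import torch.nn as nn

# stand-ins; the real block uses SpectralFilterLayer, MLP, DropPath and the configured norms
filter_ = lambda x: (x, x)  # the spectral filter returns (output, residual)
inner_skip = outer_skip = nn.Identity()
act_layer, mlp, drop_path = nn.GELU(), nn.Identity(), nn.Identity()
norm0, norm1 = nn.Identity(), nn.Identity()

def block_forward(x):
    x, residual = filter_(x)   # spectral filter half
    x = x + inner_skip(residual)
    x = act_layer(x)
    x = norm0(x)               # first norm now closes the filter half
    x = mlp(x)                 # MLP half
    x = drop_path(x)
    x = x + outer_skip(residual)
    x = norm1(x)               # second norm closes the block
    return x

print(block_forward(torch.randn(1, 4, 8, 8)).shape)  # torch.Size([1, 4, 8, 8])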
class SphericalFourierNeuralOperatorNet(nn.Module):
@@ -229,7 +249,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
     spectral_transform : str, optional
         Type of spectral transformation to use, by default "sht"
     operator_type : str, optional
-        Type of operator to use ('vector', 'diagonal'), by default "vector"
+        Type of operator to use ('driscoll-healy', 'diagonal'), by default "driscoll-healy"
     img_shape : tuple, optional
         Shape of the input channels, by default (128, 256)
     scale_factor : int, optional
@@ -247,7 +267,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
     encoder_layers : int, optional
         Number of layers in the encoder, by default 1
     use_mlp : int, optional
-        Whether to use MLP, by default True
+        Whether to use MLPs in the SFNO blocks, by default True
     mlp_ratio : int, optional
         Ratio of MLP to use, by default 2.0
     drop_rate : float, optional
@@ -266,6 +286,8 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         Whether to add a single large skip connection, by default True
     rank : float, optional
         Rank of the approximation, by default 1.0
+    lr_scale_exponent : float, optional
+        exponential rescaling of spectral coefficients, by default 0.0 (no rescaling)
     factorization : Any, optional
         Type of factorization to use, by default None
     separable : bool, optional
@@ -287,10 +309,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
     ...     in_chans=2,
     ...     out_chans=2,
     ...     embed_dim=16,
-    ...     num_layers=2,
-    ...     encoder_layers=1,
-    ...     num_blocks=4,
-    ...     spectral_layers=2,
+    ...     num_layers=4,
     ...     use_mlp=True,)
     >>> model(torch.randn(1, 2, 128, 256)).shape
     torch.Size([1, 2, 128, 256])
@@ -298,30 +317,31 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
     def __init__(self,
-                 filter_type = 'linear',
-                 spectral_transform = 'sht',
-                 operator_type = 'vector',
+                 filter_type = "linear",
+                 spectral_transform = "sht",
+                 operator_type = "driscoll-healy",
                  img_size = (128, 256),
                  scale_factor = 3,
                  in_chans = 3,
                  out_chans = 3,
                  embed_dim = 256,
                  num_layers = 4,
-                 activation_function = 'gelu',
+                 activation_function = "gelu",
                  encoder_layers = 1,
                  use_mlp = True,
                  mlp_ratio = 2.,
                  drop_rate = 0.,
                  drop_path_rate = 0.,
                  sparsity_threshold = 0.0,
-                 normalization_layer = 'instance_norm',
+                 normalization_layer = "none",
                  hard_thresholding_fraction = 1.0,
                  use_complex_kernels = True,
                  big_skip = True,
+                 lr_scale_exponent = 0,
                  factorization = None,
                  separable = False,
                  rank = 128,
-                 complex_activation = 'real',
+                 complex_activation = "real",
                  spectral_layers = 2,
                  pos_embed = True):
@@ -342,6 +362,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         self.use_mlp = use_mlp
         self.encoder_layers = encoder_layers
         self.big_skip = big_skip
+        self.lr_scale_exponent = lr_scale_exponent
         self.factorization = factorization
         self.separable = separable,
         self.rank = rank
@@ -349,9 +370,9 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
         self.spectral_layers = spectral_layers

         # activation function
-        if activation_function == 'relu':
+        if activation_function == "relu":
             self.activation_function = nn.ReLU
-        elif activation_function == 'gelu':
+        elif activation_function == "gelu":
             self.activation_function = nn.GELU
         else:
             raise ValueError(f"Unknown activation function {activation_function}")
@@ -383,28 +404,28 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
             self.pos_embed = None

         # encoder
-        encoder_hidden_dim = self.embed_dim
-        current_dim = self.in_chans
-        encoder_modules = []
-        for i in range(self.encoder_layers):
-            encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True))
-            encoder_modules.append(self.activation_function())
-            current_dim = encoder_hidden_dim
-        encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False))
-        self.encoder = nn.Sequential(*encoder_modules)
+        encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
+        encoder = MLP(in_features = self.in_chans,
+                      out_features = self.embed_dim,
+                      hidden_features = encoder_hidden_dim,
+                      act_layer = self.activation_function,
+                      drop_rate = drop_rate,
+                      checkpointing = False)
+        self.encoder = encoder
+        # self.encoder = nn.Sequential(encoder, norm_layer0())

         # prepare the spectral transform
-        if self.spectral_transform == 'sht':
+        if self.spectral_transform == "sht":
             modes_lat = int(self.h * self.hard_thresholding_fraction)
             modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)

-            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
-            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid='equiangular').float()
-            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
-            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid='legendre-gauss').float()
+            self.trans_down = RealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
+            self.itrans_up = InverseRealSHT(*self.img_size, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
+            self.trans = RealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()
+            self.itrans = InverseRealSHT(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()

-        elif self.spectral_transform == 'fft':
+        elif self.spectral_transform == "fft":
             modes_lat = int(self.h * self.hard_thresholding_fraction)
             modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction)
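The model keeps two pairs of spherical harmonic transforms: trans_down/itrans_up work on the full-resolution equiangular input grid, while trans/itrans work on the internal Legendre-Gauss grid used inside the blocks. A small sketch of that setup with the torch_harmonics API; deriving the internal grid as img_size // scale_factor is an assumption about self.h and self.w:

import torch
from torch_harmonics import RealSHT, InverseRealSHT

img_size, scale_factor = (128, 256), 3
h, w = img_size[0] // scale_factor, img_size[1] // scale_factor  # assumed internal grid
modes_lat, modes_lon = h, w // 2 + 1

trans_down = RealSHT(*img_size, lmax=modes_lat, mmax=modes_lon, grid="equiangular").float()
itrans = InverseRealSHT(h, w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float()

x = torch.randn(1, *img_size)
coeffs = trans_down(x)       # truncated spectrum: (1, modes_lat, modes_lon), complex
print(itrans(coeffs).shape)  # resampled onto the internal grid: torch.Size([1, 42, 85])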
@@ -415,7 +436,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
             self.itrans = InverseRealFFT2(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float()
         else:
-            raise(ValueError('Unknown spectral transform'))
+            raise(ValueError("Unknown spectral transform"))

         self.blocks = nn.ModuleList([])
         for i in range(self.num_layers):
@@ -430,11 +451,11 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
             outer_skip = 'identity'

             if first_layer:
-                norm_layer = (norm_layer0, norm_layer1)
+                norm_layer = norm_layer1
             elif last_layer:
-                norm_layer = (norm_layer1, norm_layer0)
+                norm_layer = norm_layer0
             else:
-                norm_layer = (norm_layer1, norm_layer1)
+                norm_layer = norm_layer1

             block = SphericalFourierNeuralOperatorBlock(forward_transform,
                                                         inverse_transform,
@@ -451,6 +472,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
                                                         inner_skip = inner_skip,
                                                         outer_skip = outer_skip,
                                                         use_mlp = use_mlp,
+                                                        lr_scale_exponent = self.lr_scale_exponent,
                                                         factorization = self.factorization,
                                                         separable = self.separable,
                                                         rank = self.rank,
@@ -460,15 +482,13 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
             self.blocks.append(block)

         # decoder
-        decoder_hidden_dim = self.embed_dim
-        current_dim = self.embed_dim + self.big_skip * self.in_chans
-        decoder_modules = []
-        for i in range(self.encoder_layers):
-            decoder_modules.append(nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True))
-            decoder_modules.append(self.activation_function())
-            current_dim = decoder_hidden_dim
-        decoder_modules.append(nn.Conv2d(current_dim, self.out_chans, 1, bias=False))
-        self.decoder = nn.Sequential(*decoder_modules)
+        encoder_hidden_dim = int(self.embed_dim * mlp_ratio)
+        self.decoder = MLP(in_features = self.embed_dim + self.big_skip * self.in_chans,
+                           out_features = self.out_chans,
+                           hidden_features = encoder_hidden_dim,
+                           act_layer = self.activation_function,
+                           drop_rate = drop_rate,
+                           checkpointing = False)

         # trunc_normal_(self.pos_embed, std=.02)
         self.apply(self._init_weights)
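The decoder change mirrors the encoder: the hand-rolled stack of 1x1 convolutions is replaced by the shared MLP block, whose input width grows by in_chans when big_skip concatenates the network input back in. A rough functional stand-in for that wiring (the Sequential here is a placeholder for the MLP from layers.py):

import torch
import torch.nn as nn

embed_dim, in_chans, out_chans, mlp_ratio, big_skip = 16, 3, 3, 2.0, True

decoder = nn.Sequential(
    nn.Conv2d(embed_dim + big_skip * in_chans, int(embed_dim * mlp_ratio), 1),
    nn.GELU(),
    nn.Conv2d(int(embed_dim * mlp_ratio), out_chans, 1, bias=False),
)

features = torch.randn(1, embed_dim, 32, 64)  # output of the SFNO blocks
inp = torch.randn(1, in_chans, 32, 64)        # network input carried by the big skip
x = torch.cat((features, inp), dim=1) if big_skip else features
print(decoder(x).shape)                       # torch.Size([1, 3, 32, 64])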
@@ -482,7 +502,7 @@ class SphericalFourierNeuralOperatorNet(nn.Module):
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'pos_embed', 'cls_token'}
+        return {"pos_embed", "cls_token"}

     def forward_features(self, x):