Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
0c067c86
Commit
0c067c86
authored
Dec 17, 2024
by
Boris Bonev
Committed by
Boris Bonev
Jan 14, 2025
Browse files
intermediate release with reworked normalization of S2 convolutions
parent
34927a33
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
119 additions
and
35 deletions
+119
-35
Changelog.md
Changelog.md
+2
-0
torch_harmonics/convolution.py
torch_harmonics/convolution.py
+73
-20
torch_harmonics/distributed/distributed_convolution.py
torch_harmonics/distributed/distributed_convolution.py
+36
-8
torch_harmonics/examples/models/lsno.py
torch_harmonics/examples/models/lsno.py
+2
-2
torch_harmonics/filter_basis.py
torch_harmonics/filter_basis.py
+6
-5
No files found.
Changelog.md
View file @
0c067c86
...
...
@@ -8,8 +8,10 @@
*
Hotfix to the numpy version requirements
*
Changing default grid in all SHT routines to
`equiangular`
, which makes it consistent with DISCO convolutions
*
Cleaning up the SFNO example and adding new Local Spherical Neural Operator model
*
New filter basis normalization in DISCO convolutions
*
Reworked DISCO filter basis datastructure
*
Support for new filter basis types
*
Adding Morlet-like basis functions on a spherical disk
### v0.7.2
...
...
torch_harmonics/convolution.py
View file @
0c067c86
...
...
@@ -56,9 +56,11 @@ except ImportError as err:
_cuda_extension_available
=
False
def
_normalize_convolution_tensor_s2
(
psi_idx
,
psi_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
False
,
merge_quadrature
=
False
,
eps
=
1e-9
):
def
_normalize_convolution_tensor_s2
(
psi_idx
,
psi_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
False
,
basis_norm_mode
=
"sum"
,
merge_quadrature
=
False
,
eps
=
1e-9
):
"""
Discretely normalizes the convolution tensor.
Discretely normalizes the convolution tensor.
Supports different normalization modes
"""
nlat_in
,
nlon_in
=
in_shape
...
...
@@ -74,10 +76,21 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
# loop through dimensions which require normalization
for
ik
in
range
(
kernel_size
):
for
ilat
in
range
(
nlat_in
):
# get relevant entries
# get relevant entries depending on the normalization mode
if
basis_norm_mode
==
"individual"
:
iidx
=
torch
.
argwhere
((
idx
[
0
]
==
ik
)
&
(
idx
[
2
]
==
ilat
))
# normalize, while summing also over the input longitude dimension here as this is not available for the output
vnorm
=
torch
.
sum
(
psi_vals
[
iidx
]
*
q
[
iidx
])
vnorm
=
torch
.
sum
(
psi_vals
[
iidx
].
abs
()
*
q
[
iidx
])
elif
basis_norm_mode
==
"sum"
:
# this will perform repeated normalization in the kernel dimension but this shouldn't lead to issues
iidx
=
torch
.
argwhere
(
idx
[
2
]
==
ilat
)
# normalize, while summing also over the input longitude dimension here as this is not available for the output
vnorm
=
torch
.
sum
(
psi_vals
[
iidx
].
abs
()
*
q
[
iidx
])
else
:
raise
ValueError
(
f
"Unknown basis normalization mode
{
basis_norm_mode
}
."
)
if
merge_quadrature
:
# the correction factor accounts for the difference in longitudinal grid points when the input vector is upscaled
psi_vals
[
iidx
]
=
psi_vals
[
iidx
]
*
q
[
iidx
]
*
nlon_in
/
nlon_out
/
(
vnorm
+
eps
)
...
...
@@ -90,10 +103,20 @@ def _normalize_convolution_tensor_s2(psi_idx, psi_vals, in_shape, out_shape, ker
# loop through dimensions which require normalization
for
ik
in
range
(
kernel_size
):
for
ilat
in
range
(
nlat_out
):
# get relevant entries
# get relevant entries depending on the normalization mode
if
basis_norm_mode
==
"individual"
:
iidx
=
torch
.
argwhere
((
idx
[
0
]
==
ik
)
&
(
idx
[
1
]
==
ilat
))
# normalize
vnorm
=
torch
.
sum
(
psi_vals
[
iidx
]
*
q
[
iidx
])
vnorm
=
torch
.
sum
(
psi_vals
[
iidx
].
abs
()
*
q
[
iidx
])
elif
basis_norm_mode
==
"sum"
:
# this will perform repeated normalization in the kernel dimension but this shouldn't lead to issues
iidx
=
torch
.
argwhere
(
idx
[
1
]
==
ilat
)
# normalize
vnorm
=
torch
.
sum
(
psi_vals
[
iidx
].
abs
()
*
q
[
iidx
])
else
:
raise
ValueError
(
f
"Unknown basis normalization mode
{
basis_norm_mode
}
."
)
if
merge_quadrature
:
psi_vals
[
iidx
]
=
psi_vals
[
iidx
]
*
q
[
iidx
]
/
(
vnorm
+
eps
)
else
:
...
...
@@ -110,6 +133,7 @@ def _precompute_convolution_tensor_s2(
grid_out
=
"equiangular"
,
theta_cutoff
=
0.01
*
math
.
pi
,
transpose_normalization
=
False
,
basis_norm_mode
=
"sum"
,
merge_quadrature
=
False
,
):
"""
...
...
@@ -136,6 +160,7 @@ def _precompute_convolution_tensor_s2(
nlat_in
,
nlon_in
=
in_shape
nlat_out
,
nlon_out
=
out_shape
# precompute input and output grids
lats_in
,
win
=
_precompute_latitudes
(
nlat_in
,
grid
=
grid_in
)
lats_in
=
torch
.
from_numpy
(
lats_in
).
float
()
lats_out
,
wout
=
_precompute_latitudes
(
nlat_out
,
grid
=
grid_out
)
...
...
@@ -145,6 +170,12 @@ def _precompute_convolution_tensor_s2(
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in
=
torch
.
linspace
(
0
,
2
*
math
.
pi
,
nlon_in
+
1
)[:
-
1
]
# compute quadrature weights that will be merged into the Psi tensor
if
transpose_normalization
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
wout
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
else
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
win
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
out_idx
=
[]
out_vals
=
[]
for
t
in
range
(
nlat_out
):
...
...
@@ -185,12 +216,16 @@ def _precompute_convolution_tensor_s2(
out_idx
=
torch
.
cat
(
out_idx
,
dim
=-
1
).
to
(
torch
.
long
).
contiguous
()
out_vals
=
torch
.
cat
(
out_vals
,
dim
=-
1
).
to
(
torch
.
float32
).
contiguous
()
if
transpose_normalization
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
wout
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
else
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
win
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
out_vals
=
_normalize_convolution_tensor_s2
(
out_idx
,
out_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
transpose_normalization
,
merge_quadrature
=
merge_quadrature
out_idx
,
out_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
transpose_normalization
,
basis_norm_mode
=
basis_norm_mode
,
merge_quadrature
=
merge_quadrature
,
)
return
out_idx
,
out_vals
...
...
@@ -198,7 +233,7 @@ def _precompute_convolution_tensor_s2(
class
DiscreteContinuousConv
(
nn
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""
Abstract base class for
DISCO
convolutions
Abstract base class for
discrete-continuous
convolutions
"""
def
__init__
(
...
...
@@ -245,7 +280,7 @@ class DiscreteContinuousConv(nn.Module, metaclass=abc.ABCMeta):
class
DiscreteContinuousConvS2
(
DiscreteContinuousConv
):
"""
Discrete-continuous convolutions
(DISCO)
on the 2-Sphere as described in [1].
Discrete-continuous
(DISCO)
convolutions on the 2-Sphere as described in [1].
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
...
...
@@ -258,6 +293,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape
:
Tuple
[
int
],
kernel_shape
:
Union
[
int
,
List
[
int
]],
basis_type
:
Optional
[
str
]
=
"piecewise linear"
,
basis_norm_mode
:
Optional
[
str
]
=
"sum"
,
groups
:
Optional
[
int
]
=
1
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
...
...
@@ -277,7 +313,15 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
raise
ValueError
(
"Error, theta_cutoff has to be positive."
)
idx
,
vals
=
_precompute_convolution_tensor_s2
(
in_shape
,
out_shape
,
self
.
filter_basis
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
theta_cutoff
=
theta_cutoff
,
transpose_normalization
=
False
,
merge_quadrature
=
True
in_shape
,
out_shape
,
self
.
filter_basis
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
theta_cutoff
=
theta_cutoff
,
transpose_normalization
=
False
,
basis_norm_mode
=
basis_norm_mode
,
merge_quadrature
=
True
,
)
# sort the values
...
...
@@ -339,7 +383,7 @@ class DiscreteContinuousConvS2(DiscreteContinuousConv):
class
DiscreteContinuousConvTransposeS2
(
DiscreteContinuousConv
):
"""
Discrete-continuous transpose convolutions
(DISCO)
on the 2-Sphere as described in [1].
Discrete-continuous
(DISCO)
transpose convolutions on the 2-Sphere as described in [1].
[1] Ocampo, Price, McEwen, Scalable and equivariant spherical CNNs by discrete-continuous (DISCO) convolutions, ICLR (2023), arXiv:2209.13603
"""
...
...
@@ -352,6 +396,7 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape
:
Tuple
[
int
],
kernel_shape
:
Union
[
int
,
List
[
int
]],
basis_type
:
Optional
[
str
]
=
"piecewise linear"
,
basis_norm_mode
:
Optional
[
str
]
=
"sum"
,
groups
:
Optional
[
int
]
=
1
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
...
...
@@ -372,7 +417,15 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
# switch in_shape and out_shape since we want transpose conv
idx
,
vals
=
_precompute_convolution_tensor_s2
(
out_shape
,
in_shape
,
self
.
filter_basis
,
grid_in
=
grid_out
,
grid_out
=
grid_in
,
theta_cutoff
=
theta_cutoff
,
transpose_normalization
=
True
,
merge_quadrature
=
True
out_shape
,
in_shape
,
self
.
filter_basis
,
grid_in
=
grid_out
,
grid_out
=
grid_in
,
theta_cutoff
=
theta_cutoff
,
transpose_normalization
=
True
,
basis_norm_mode
=
basis_norm_mode
,
merge_quadrature
=
True
,
)
# sort the values
...
...
torch_harmonics/distributed/distributed_convolution.py
View file @
0c067c86
...
...
@@ -76,6 +76,7 @@ def _precompute_distributed_convolution_tensor_s2(
grid_out
=
"equiangular"
,
theta_cutoff
=
0.01
*
math
.
pi
,
transpose_normalization
=
False
,
basis_norm_mode
=
"sum"
,
merge_quadrature
=
False
,
):
"""
...
...
@@ -111,6 +112,12 @@ def _precompute_distributed_convolution_tensor_s2(
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in
=
torch
.
linspace
(
0
,
2
*
math
.
pi
,
nlon_in
+
1
)[:
-
1
]
# compute quadrature weights that will be merged into the Psi tensor
if
transpose_normalization
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
wout
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
else
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
win
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
out_idx
=
[]
out_vals
=
[]
for
t
in
range
(
nlat_out
):
...
...
@@ -151,13 +158,16 @@ def _precompute_distributed_convolution_tensor_s2(
out_idx
=
torch
.
cat
(
out_idx
,
dim
=-
1
).
to
(
torch
.
long
).
contiguous
()
out_vals
=
torch
.
cat
(
out_vals
,
dim
=-
1
).
to
(
torch
.
float32
).
contiguous
()
# perform the normalization over the entire psi matrix
if
transpose_normalization
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
wout
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
else
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
win
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
out_vals
=
_normalize_convolution_tensor_s2
(
out_idx
,
out_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
transpose_normalization
,
merge_quadrature
=
merge_quadrature
out_idx
,
out_vals
,
in_shape
,
out_shape
,
kernel_size
,
quad_weights
,
transpose_normalization
=
transpose_normalization
,
basis_norm_mode
=
basis_norm_mode
,
merge_quadrature
=
merge_quadrature
,
)
# TODO: this part can be split off into it's own function
...
...
@@ -197,6 +207,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
out_shape
:
Tuple
[
int
],
kernel_shape
:
Union
[
int
,
List
[
int
]],
basis_type
:
Optional
[
str
]
=
"piecewise linear"
,
basis_norm_mode
:
Optional
[
str
]
=
"sum"
,
groups
:
Optional
[
int
]
=
1
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
...
...
@@ -236,7 +247,15 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
self
.
nlat_out_local
=
self
.
nlat_out
idx
,
vals
=
_precompute_distributed_convolution_tensor_s2
(
in_shape
,
out_shape
,
self
.
filter_basis
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
theta_cutoff
=
theta_cutoff
,
transpose_normalization
=
False
,
merge_quadrature
=
True
in_shape
,
out_shape
,
self
.
filter_basis
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
theta_cutoff
=
theta_cutoff
,
transpose_normalization
=
False
,
basis_norm_mode
=
basis_norm_mode
,
merge_quadrature
=
True
,
)
# sort the values
...
...
@@ -328,6 +347,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
out_shape
:
Tuple
[
int
],
kernel_shape
:
Union
[
int
,
List
[
int
]],
basis_type
:
Optional
[
str
]
=
"piecewise linear"
,
basis_norm_mode
:
Optional
[
str
]
=
"sum"
,
groups
:
Optional
[
int
]
=
1
,
grid_in
:
Optional
[
str
]
=
"equiangular"
,
grid_out
:
Optional
[
str
]
=
"equiangular"
,
...
...
@@ -369,7 +389,15 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
# switch in_shape and out_shape since we want transpose conv
# distributed mode here is swapped because of the transpose
idx
,
vals
=
_precompute_distributed_convolution_tensor_s2
(
out_shape
,
in_shape
,
self
.
filter_basis
,
grid_in
=
grid_out
,
grid_out
=
grid_in
,
theta_cutoff
=
theta_cutoff
,
transpose_normalization
=
True
,
merge_quadrature
=
True
out_shape
,
in_shape
,
self
.
filter_basis
,
grid_in
=
grid_out
,
grid_out
=
grid_in
,
theta_cutoff
=
theta_cutoff
,
transpose_normalization
=
True
,
basis_norm_mode
=
basis_norm_mode
,
merge_quadrature
=
True
,
)
# sort the values
...
...
torch_harmonics/examples/models/lsno.py
View file @
0c067c86
...
...
@@ -69,7 +69,7 @@ class DiscreteContinuousEncoder(nn.Module):
grid_out
=
grid_out
,
groups
=
groups
,
bias
=
bias
,
theta_cutoff
=
4
*
math
.
sqrt
(
2
)
*
torch
.
pi
/
float
(
out_shape
[
0
]
-
1
),
theta_cutoff
=
math
.
sqrt
(
2
)
*
torch
.
pi
/
float
(
out_shape
[
0
]
-
1
),
)
def
forward
(
self
,
x
):
...
...
@@ -115,7 +115,7 @@ class DiscreteContinuousDecoder(nn.Module):
grid_out
=
grid_out
,
groups
=
groups
,
bias
=
False
,
theta_cutoff
=
4
*
math
.
sqrt
(
2
)
*
torch
.
pi
/
float
(
inp_shape
[
0
]
-
1
),
theta_cutoff
=
math
.
sqrt
(
2
)
*
torch
.
pi
/
float
(
inp_shape
[
0
]
-
1
),
)
# self.convt = nn.Conv2d(inp_chans, out_chans, 1, bias=False)
...
...
torch_harmonics/filter_basis.py
View file @
0c067c86
...
...
@@ -47,7 +47,7 @@ def get_filter_basis(kernel_shape: Union[int, List[int], Tuple[int, int]], basis
raise
ValueError
(
f
"Unknown basis_type
{
basis_type
}
"
)
class
Abstract
FilterBasis
(
metaclass
=
abc
.
ABCMeta
):
class
FilterBasis
(
metaclass
=
abc
.
ABCMeta
):
"""
Abstract base class for a filter basis
"""
...
...
@@ -72,7 +72,7 @@ class AbstractFilterBasis(metaclass=abc.ABCMeta):
raise
NotImplementedError
class
PiecewiseLinearFilterBasis
(
Abstract
FilterBasis
):
class
PiecewiseLinearFilterBasis
(
FilterBasis
):
"""
Tensor-product basis on a disk constructed from piecewise linear basis functions.
"""
...
...
@@ -190,7 +190,7 @@ class PiecewiseLinearFilterBasis(AbstractFilterBasis):
else
:
return
self
.
_compute_support_vals_isotropic
(
r
,
phi
,
r_cutoff
=
r_cutoff
)
class
DiskMorletFilterBasis
(
Abstract
FilterBasis
):
class
DiskMorletFilterBasis
(
FilterBasis
):
"""
Morlet-like Filter basis. A Gaussian is multiplied with a Fourier basis in x and y.
"""
...
...
@@ -228,7 +228,8 @@ class DiskMorletFilterBasis(AbstractFilterBasis):
iidx
=
torch
.
argwhere
((
r
<=
r_cutoff
)
&
torch
.
full_like
(
ikernel
,
True
,
dtype
=
torch
.
bool
))
# # computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
width
=
0.01
# width = 0.01
width
=
0.25
# width = 1.0
# envelope = self._gaussian_envelope(r, width=0.25 * r_cutoff)
...
...
@@ -245,7 +246,7 @@ class DiskMorletFilterBasis(AbstractFilterBasis):
# computes the Gaussian envelope. To ensure that the curve is roughly 0 at the boundary, we rescale the Gaussian by 0.25
vals
=
self
.
_gaussian_envelope
(
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
/
r_cutoff
,
width
=
width
)
*
harmonic
[
iidx
[:,
0
],
iidx
[:,
1
],
iidx
[:,
2
]]
/
disk_area
vals
=
torch
.
ones_like
(
vals
)
#
vals = torch.ones_like(vals)
return
iidx
,
vals
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment