Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
54502a17
Unverified
Commit
54502a17
authored
Jan 30, 2024
by
Boris Bonev
Committed by
GitHub
Jan 30, 2024
Browse files
Bbonev/disco refactor (#29)
* Cleaned up DISCO convolutions
parent
c971d458
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
196 additions
and
117 deletions
+196
-117
Changelog.md
Changelog.md
+5
-4
torch_harmonics/__init__.py
torch_harmonics/__init__.py
+1
-1
torch_harmonics/convolution.py
torch_harmonics/convolution.py
+155
-104
torch_harmonics/quadrature.py
torch_harmonics/quadrature.py
+35
-8
No files found.
Changelog.md
View file @
54502a17
...
@@ -4,10 +4,11 @@
...
@@ -4,10 +4,11 @@
### v0.6.5
### v0.6.5
*
Discrrete-continuous (DISCO) convolutions on the sphere
*
Discrete-continuous (DISCO) convolutions on the sphere and in two dimensions
*
Isotropic and anisotropic DISCO convolutions
*
DISCO supports isotropic and anisotropic kernel functions parameterized as hat functions
*
Accelerated DISCO convolutions on GPU via Triton implementation
*
Supports regular and transpose convolutions
*
Unittests for DISCO convolutions
*
Accelerated spherical DISCO convolutions on GPU via Triton implementation
*
Unittests for DISCO convolutions in
`tests/test_convolution.py`
### v0.6.4
### v0.6.4
...
...
torch_harmonics/__init__.py
View file @
54502a17
...
@@ -29,7 +29,7 @@
...
@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#
__version__
=
'0.6.
4
'
__version__
=
'0.6.
5
'
from
.sht
import
RealSHT
,
InverseRealSHT
,
RealVectorSHT
,
InverseRealVectorSHT
from
.sht
import
RealSHT
,
InverseRealSHT
,
RealVectorSHT
,
InverseRealVectorSHT
from
.convolution
import
DiscreteContinuousConvS2
,
DiscreteContinuousConvTransposeS2
from
.convolution
import
DiscreteContinuousConvS2
,
DiscreteContinuousConvTransposeS2
...
...
torch_harmonics/convolution.py
View file @
54502a17
...
@@ -29,6 +29,7 @@
...
@@ -29,6 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#
import
abc
from
typing
import
List
,
Tuple
,
Union
,
Optional
from
typing
import
List
,
Tuple
,
Union
,
Optional
import
math
import
math
...
@@ -38,7 +39,7 @@ import torch.nn as nn
...
@@ -38,7 +39,7 @@ import torch.nn as nn
from
functools
import
partial
from
functools
import
partial
from
torch_harmonics.quadrature
import
_precompute_latitudes
from
torch_harmonics.quadrature
import
_precompute_grid
,
_precompute_latitudes
from
torch_harmonics._disco_convolution
import
(
from
torch_harmonics._disco_convolution
import
(
_disco_s2_contraction_torch
,
_disco_s2_contraction_torch
,
_disco_s2_transpose_contraction_torch
,
_disco_s2_transpose_contraction_torch
,
...
@@ -47,50 +48,67 @@ from torch_harmonics._disco_convolution import (
...
@@ -47,50 +48,67 @@ from torch_harmonics._disco_convolution import (
)
)
def
_compute_support_vals_isotropic
(
theta
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
n
theta
:
int
,
theta
_cutoff
:
float
):
def
_compute_support_vals_isotropic
(
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
n
r
:
int
,
r
_cutoff
:
float
,
norm
:
str
=
"s2"
):
"""
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
"""
# compute the support
# compute the support
dtheta
=
(
theta_cutoff
-
0.0
)
/
ntheta
dr
=
(
r_cutoff
-
0.0
)
/
nr
ikernel
=
torch
.
arange
(
ntheta
).
reshape
(
-
1
,
1
,
1
)
ikernel
=
torch
.
arange
(
nr
).
reshape
(
-
1
,
1
,
1
)
itheta
=
ikernel
*
dtheta
ir
=
ikernel
*
dr
norm_factor
=
2
*
math
.
pi
*
(
1
-
math
.
cos
(
theta_cutoff
-
dtheta
)
+
math
.
cos
(
theta_cutoff
-
dtheta
)
+
(
math
.
sin
(
theta_cutoff
-
dtheta
)
-
math
.
sin
(
theta_cutoff
))
/
dtheta
)
if
norm
==
"none"
:
norm_factor
=
1.0
elif
norm
==
"2d"
:
norm_factor
=
math
.
pi
*
(
r_cutoff
*
nr
/
(
nr
+
1
))
**
2
+
math
.
pi
*
r_cutoff
**
2
*
(
2
*
nr
/
(
nr
+
1
)
+
1
)
/
(
nr
+
1
)
/
3
elif
norm
==
"s2"
:
norm_factor
=
2
*
math
.
pi
*
(
1
-
math
.
cos
(
r_cutoff
-
dr
)
+
math
.
cos
(
r_cutoff
-
dr
)
+
(
math
.
sin
(
r_cutoff
-
dr
)
-
math
.
sin
(
r_cutoff
))
/
dr
)
else
:
raise
ValueError
(
f
"Unknown normalization mode
{
norm
}
."
)
# find the indices where the rotated position falls into the support of the kernel
# find the indices where the rotated position falls into the support of the kernel
iidx
=
torch
.
argwhere
(((
theta
-
itheta
).
abs
()
<=
dtheta
)
&
(
theta
<=
theta
_cutoff
))
iidx
=
torch
.
argwhere
(((
r
-
ir
).
abs
()
<=
dr
)
&
(
r
<=
r
_cutoff
))
vals
=
(
1
-
(
theta
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
i
theta
[
iidx
[:,
0
],
0
,
0
]).
abs
()
/
d
theta
)
/
norm_factor
vals
=
(
1
-
(
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
i
r
[
iidx
[:,
0
],
0
,
0
]).
abs
()
/
d
r
)
/
norm_factor
return
iidx
,
vals
return
iidx
,
vals
def
_compute_support_vals_anisotropic
(
theta
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
ntheta
:
int
,
nphi
:
int
,
theta_cutoff
:
float
):
def
_compute_support_vals_anisotropic
(
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
nr
:
int
,
nphi
:
int
,
r_cutoff
:
float
,
norm
:
str
=
"s2"
):
"""
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
"""
"""
# compute the support
# compute the support
d
theta
=
(
theta
_cutoff
-
0.0
)
/
n
theta
d
r
=
(
r
_cutoff
-
0.0
)
/
n
r
dphi
=
2.0
*
math
.
pi
/
nphi
dphi
=
2.0
*
math
.
pi
/
nphi
kernel_size
=
(
n
theta
-
1
)
*
nphi
+
1
kernel_size
=
(
n
r
-
1
)
*
nphi
+
1
ikernel
=
torch
.
arange
(
kernel_size
).
reshape
(
-
1
,
1
,
1
)
ikernel
=
torch
.
arange
(
kernel_size
).
reshape
(
-
1
,
1
,
1
)
i
theta
=
((
ikernel
-
1
)
//
nphi
+
1
)
*
d
theta
i
r
=
((
ikernel
-
1
)
//
nphi
+
1
)
*
d
r
iphi
=
((
ikernel
-
1
)
%
nphi
)
*
dphi
iphi
=
((
ikernel
-
1
)
%
nphi
)
*
dphi
norm_factor
=
2
*
math
.
pi
*
(
1
-
math
.
cos
(
theta_cutoff
-
dtheta
)
+
math
.
cos
(
theta_cutoff
-
dtheta
)
+
(
math
.
sin
(
theta_cutoff
-
dtheta
)
-
math
.
sin
(
theta_cutoff
))
/
dtheta
)
if
norm
==
"none"
:
norm_factor
=
1.0
elif
norm
==
"2d"
:
norm_factor
=
math
.
pi
*
(
r_cutoff
*
nr
/
(
nr
+
1
))
**
2
+
math
.
pi
*
r_cutoff
**
2
*
(
2
*
nr
/
(
nr
+
1
)
+
1
)
/
(
nr
+
1
)
/
3
elif
norm
==
"s2"
:
norm_factor
=
2
*
math
.
pi
*
(
1
-
math
.
cos
(
r_cutoff
-
dr
)
+
math
.
cos
(
r_cutoff
-
dr
)
+
(
math
.
sin
(
r_cutoff
-
dr
)
-
math
.
sin
(
r_cutoff
))
/
dr
)
else
:
raise
ValueError
(
f
"Unknown normalization mode
{
norm
}
."
)
# find the indices where the rotated position falls into the support of the kernel
# find the indices where the rotated position falls into the support of the kernel
cond_theta
=
((
theta
-
itheta
).
abs
()
<=
dtheta
)
&
(
theta
<=
theta_cutoff
)
cond_r
=
((
r
-
ir
).
abs
()
<=
dr
)
&
(
r
<=
r_cutoff
)
cond_phi
=
(
ikernel
==
0
)
|
((
phi
-
iphi
).
abs
()
<=
dphi
)
|
((
2
*
math
.
pi
-
(
phi
-
iphi
).
abs
())
<=
dphi
)
cond_phi
=
(
ikernel
==
0
)
|
((
phi
-
iphi
).
abs
()
<=
dphi
)
|
((
2
*
math
.
pi
-
(
phi
-
iphi
).
abs
())
<=
dphi
)
iidx
=
torch
.
argwhere
(
cond_theta
&
cond_phi
)
iidx
=
torch
.
argwhere
(
cond_r
&
cond_phi
)
vals
=
(
1
-
(
theta
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
itheta
[
iidx
[:,
0
],
0
,
0
]).
abs
()
/
dtheta
)
/
norm_factor
vals
=
(
1
-
(
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
ir
[
iidx
[:,
0
],
0
,
0
]).
abs
()
/
dr
)
/
norm_factor
vals
*=
torch
.
where
(
iidx
[:,
0
]
>
0
,
(
1
-
torch
.
minimum
((
phi
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
iphi
[
iidx
[:,
0
],
0
,
0
]).
abs
(),
(
2
*
math
.
pi
-
(
phi
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
iphi
[
iidx
[:,
0
],
0
,
0
]).
abs
())
)
/
dphi
),
1.0
)
vals
*=
torch
.
where
(
iidx
[:,
0
]
>
0
,
(
1
-
torch
.
minimum
((
phi
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
iphi
[
iidx
[:,
0
],
0
,
0
]).
abs
(),
(
2
*
math
.
pi
-
(
phi
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
iphi
[
iidx
[:,
0
],
0
,
0
]).
abs
()))
/
dphi
),
1.0
,
)
return
iidx
,
vals
return
iidx
,
vals
def
_precompute_convolution_tensor
(
def
_precompute_convolution_tensor_s2
(
in_shape
,
out_shape
,
kernel_shape
,
grid_in
=
"equiangular"
,
grid_out
=
"equiangular"
,
theta_cutoff
=
0.01
*
math
.
pi
):
in_shape
,
out_shape
,
kernel_shape
,
grid_in
=
"equiangular"
,
grid_out
=
"equiangular"
,
theta_cutoff
=
0.01
*
math
.
pi
):
"""
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i
\n
u = Y(-
\t
heta_j)Z(\phi_i - \phi_j)Y(
\t
heta_j)
\n
u$.
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i
\n
u = Y(-
\t
heta_j)Z(\phi_i - \phi_j)Y(
\t
heta_j)
\n
u$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
...
@@ -111,9 +129,9 @@ def _precompute_convolution_tensor(
...
@@ -111,9 +129,9 @@ def _precompute_convolution_tensor(
assert
len
(
out_shape
)
==
2
assert
len
(
out_shape
)
==
2
if
len
(
kernel_shape
)
==
1
:
if
len
(
kernel_shape
)
==
1
:
kernel_handle
=
partial
(
_compute_support_vals_isotropic
,
n
theta
=
kernel_shape
[
0
],
theta
_cutoff
=
theta_cutoff
)
kernel_handle
=
partial
(
_compute_support_vals_isotropic
,
n
r
=
kernel_shape
[
0
],
r
_cutoff
=
theta_cutoff
,
norm
=
"s2"
)
elif
len
(
kernel_shape
)
==
2
:
elif
len
(
kernel_shape
)
==
2
:
kernel_handle
=
partial
(
_compute_support_vals_anisotropic
,
n
theta
=
kernel_shape
[
0
],
nphi
=
kernel_shape
[
1
],
theta
_cutoff
=
theta_cutoff
)
kernel_handle
=
partial
(
_compute_support_vals_anisotropic
,
n
r
=
kernel_shape
[
0
],
nphi
=
kernel_shape
[
1
],
r
_cutoff
=
theta_cutoff
,
norm
=
"s2"
)
else
:
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
...
@@ -131,24 +149,24 @@ def _precompute_convolution_tensor(
...
@@ -131,24 +149,24 @@ def _precompute_convolution_tensor(
# compute the phi differences
# compute the phi differences
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in
=
torch
.
linspace
(
0
,
2
*
math
.
pi
,
nlon_in
+
1
)[:
-
1
]
lons_in
=
torch
.
linspace
(
0
,
2
*
math
.
pi
,
nlon_in
+
1
)[:
-
1
]
for
t
in
range
(
nlat_out
):
for
t
in
range
(
nlat_out
):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
alpha
=
-
lats_out
[
t
]
alpha
=
-
lats_out
[
t
]
beta
=
lons_in
beta
=
lons_in
gamma
=
lats_in
.
reshape
(
-
1
,
1
)
gamma
=
lats_in
.
reshape
(
-
1
,
1
)
# compute cartesian coordinates of the rotated position
# compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign
# and therefore applied with a negative sign
z
=
-
torch
.
cos
(
beta
)
*
torch
.
sin
(
alpha
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
alpha
)
*
torch
.
cos
(
gamma
)
z
=
-
torch
.
cos
(
beta
)
*
torch
.
sin
(
alpha
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
alpha
)
*
torch
.
cos
(
gamma
)
x
=
torch
.
cos
(
alpha
)
*
torch
.
cos
(
beta
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
gamma
)
*
torch
.
sin
(
alpha
)
x
=
torch
.
cos
(
alpha
)
*
torch
.
cos
(
beta
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
gamma
)
*
torch
.
sin
(
alpha
)
y
=
torch
.
sin
(
beta
)
*
torch
.
sin
(
gamma
)
y
=
torch
.
sin
(
beta
)
*
torch
.
sin
(
gamma
)
# normalization is emportant to avoid NaNs when arccos and atan are applied
# normalization is emportant to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
# this can otherwise lead to spurious artifacts in the solution
norm
=
torch
.
sqrt
(
x
*
x
+
y
*
y
+
z
*
z
)
norm
=
torch
.
sqrt
(
x
*
x
+
y
*
y
+
z
*
z
)
x
=
x
/
norm
x
=
x
/
norm
y
=
y
/
norm
y
=
y
/
norm
z
=
z
/
norm
z
=
z
/
norm
...
@@ -170,9 +188,96 @@ def _precompute_convolution_tensor(
...
@@ -170,9 +188,96 @@ def _precompute_convolution_tensor(
return
out_idx
,
out_vals
return
out_idx
,
out_vals
# TODO:
def
_precompute_convolution_tensor_2d
(
grid_in
,
grid_out
,
kernel_shape
,
radius_cutoff
=
0.01
,
periodic
=
False
):
# - derive conv and conv transpose from single module
"""
class
DiscreteContinuousConvS2
(
nn
.
Module
):
Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i
\n
u$. Similar to the S2 routine,
only that it assumes a non-periodic subset of the euclidean plane
"""
# check that input arrays are valid point clouds in 2D
assert
len
(
grid_in
)
==
2
assert
len
(
grid_out
)
==
2
assert
grid_in
.
shape
[
0
]
==
2
assert
grid_out
.
shape
[
0
]
==
2
n_in
=
grid_in
.
shape
[
-
1
]
n_out
=
grid_out
.
shape
[
-
1
]
if
len
(
kernel_shape
)
==
1
:
kernel_handle
=
partial
(
_compute_support_vals_isotropic
,
nr
=
kernel_shape
[
0
],
r_cutoff
=
radius_cutoff
,
norm
=
"2d"
)
elif
len
(
kernel_shape
)
==
2
:
kernel_handle
=
partial
(
_compute_support_vals_anisotropic
,
nr
=
kernel_shape
[
0
],
nphi
=
kernel_shape
[
1
],
r_cutoff
=
radius_cutoff
,
norm
=
"2d"
)
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
grid_in
=
grid_in
.
reshape
(
2
,
1
,
n_in
)
grid_out
=
grid_out
.
reshape
(
2
,
n_out
,
1
)
diffs
=
grid_in
-
grid_out
if
periodic
:
periodic_diffs
=
torch
.
where
(
diffs
>
0.0
,
diffs
-
1
,
diffs
+
1
)
diffs
=
torch
.
where
(
diffs
.
abs
()
<
periodic_diffs
.
abs
(),
diffs
,
periodic_diffs
)
r
=
torch
.
sqrt
(
diffs
[
0
]
**
2
+
diffs
[
1
]
**
2
)
phi
=
torch
.
arctan2
(
diffs
[
1
],
diffs
[
0
])
+
torch
.
pi
idx
,
vals
=
kernel_handle
(
r
,
phi
)
idx
=
idx
.
permute
(
1
,
0
)
return
idx
,
vals
class
DiscreteContinuousConv
(
nn
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""
Abstract base class for DISCO convolutions
"""
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
kernel_shape
:
Union
[
int
,
List
[
int
]],
groups
:
Optional
[
int
]
=
1
,
bias
:
Optional
[
bool
]
=
True
,
):
super
().
__init__
()
if
isinstance
(
kernel_shape
,
int
):
self
.
kernel_shape
=
[
kernel_shape
]
else
:
self
.
kernel_shape
=
kernel_shape
if
len
(
self
.
kernel_shape
)
==
1
:
self
.
kernel_size
=
self
.
kernel_shape
[
0
]
elif
len
(
self
.
kernel_shape
)
==
2
:
self
.
kernel_size
=
(
self
.
kernel_shape
[
0
]
-
1
)
*
self
.
kernel_shape
[
1
]
+
1
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
# groups
self
.
groups
=
groups
# weight tensor
if
in_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of input channels has to be an integer multiple of the group size"
)
if
out_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
self
.
groupsize
=
in_channels
//
self
.
groups
scale
=
math
.
sqrt
(
1.0
/
self
.
groupsize
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
self
.
groupsize
,
self
.
kernel_size
))
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
else
:
self
.
bias
=
None
@
abc
.
abstractmethod
def
forward
(
self
,
x
:
torch
.
Tensor
):
raise
NotImplementedError
class
DiscreteContinuousConvS2
(
DiscreteContinuousConv
):
"""
"""
Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
...
@@ -192,24 +297,14 @@ class DiscreteContinuousConvS2(nn.Module):
...
@@ -192,24 +297,14 @@ class DiscreteContinuousConvS2(nn.Module):
bias
:
Optional
[
bool
]
=
True
,
bias
:
Optional
[
bool
]
=
True
,
theta_cutoff
:
Optional
[
float
]
=
None
,
theta_cutoff
:
Optional
[
float
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
(
in_channels
,
out_channels
,
kernel_shape
,
groups
,
bias
)
self
.
nlat_in
,
self
.
nlon_in
=
in_shape
self
.
nlat_in
,
self
.
nlon_in
=
in_shape
self
.
nlat_out
,
self
.
nlon_out
=
out_shape
self
.
nlat_out
,
self
.
nlon_out
=
out_shape
if
isinstance
(
kernel_shape
,
int
):
kernel_shape
=
[
kernel_shape
]
if
len
(
kernel_shape
)
==
1
:
self
.
kernel_size
=
kernel_shape
[
0
]
elif
len
(
kernel_shape
)
==
2
:
self
.
kernel_size
=
(
kernel_shape
[
0
]
-
1
)
*
kernel_shape
[
1
]
+
1
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
# compute theta cutoff based on the bandlimit of the input field
# compute theta cutoff based on the bandlimit of the input field
if
theta_cutoff
is
None
:
if
theta_cutoff
is
None
:
theta_cutoff
=
(
kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
self
.
nlat_in
-
1
)
theta_cutoff
=
(
self
.
kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
self
.
nlat_in
-
1
)
if
theta_cutoff
<=
0.0
:
if
theta_cutoff
<=
0.0
:
raise
ValueError
(
"Error, theta_cutoff has to be positive."
)
raise
ValueError
(
"Error, theta_cutoff has to be positive."
)
...
@@ -219,38 +314,20 @@ class DiscreteContinuousConvS2(nn.Module):
...
@@ -219,38 +314,20 @@ class DiscreteContinuousConvS2(nn.Module):
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
wgl
).
float
().
reshape
(
-
1
,
1
)
/
self
.
nlon_in
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
wgl
).
float
().
reshape
(
-
1
,
1
)
/
self
.
nlon_in
self
.
register_buffer
(
"quad_weights"
,
quad_weights
,
persistent
=
False
)
self
.
register_buffer
(
"quad_weights"
,
quad_weights
,
persistent
=
False
)
idx
,
vals
=
_precompute_convolution_tensor
(
idx
,
vals
=
_precompute_convolution_tensor_s2
(
in_shape
,
out_shape
,
self
.
kernel_shape
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
theta_cutoff
=
theta_cutoff
)
in_shape
,
out_shape
,
kernel_shape
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
theta_cutoff
=
theta_cutoff
)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
# ).coalesce()
self
.
register_buffer
(
"psi_idx"
,
idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_idx"
,
idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_vals"
,
vals
,
persistent
=
False
)
self
.
register_buffer
(
"psi_vals"
,
vals
,
persistent
=
False
)
# self.register_buffer("psi", psi, persistent=False)
# groups
self
.
groups
=
groups
# weight tensor
if
in_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of input channels has to be an integer multiple of the group size"
)
if
out_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
self
.
groupsize
=
in_channels
//
self
.
groups
scale
=
math
.
sqrt
(
1.0
/
self
.
groupsize
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
self
.
groupsize
,
self
.
kernel_size
))
if
bias
:
def
get_psi
(
self
):
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_out
,
self
.
nlat_in
*
self
.
nlon_in
)).
coalesce
()
else
:
return
psi
self
.
bias
=
None
def
forward
(
self
,
x
:
torch
.
Tensor
,
use_triton_kernel
:
bool
=
True
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
,
use_triton_kernel
:
bool
=
True
)
->
torch
.
Tensor
:
# pre-multiply x with the quadrature weights
# pre-multiply x with the quadrature weights
x
=
self
.
quad_weights
*
x
x
=
self
.
quad_weights
*
x
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_out
,
self
.
nlat_in
*
self
.
nlon_in
)).
coalesce
()
psi
=
self
.
get_psi
()
if
x
.
is_cuda
and
use_triton_kernel
:
if
x
.
is_cuda
and
use_triton_kernel
:
x
=
_disco_s2_contraction_triton
(
x
,
psi
,
self
.
nlon_out
)
x
=
_disco_s2_contraction_triton
(
x
,
psi
,
self
.
nlon_out
)
...
@@ -271,7 +348,7 @@ class DiscreteContinuousConvS2(nn.Module):
...
@@ -271,7 +348,7 @@ class DiscreteContinuousConvS2(nn.Module):
return
out
return
out
class
DiscreteContinuousConvTransposeS2
(
nn
.
Module
):
class
DiscreteContinuousConvTransposeS2
(
DiscreteContinuousConv
):
"""
"""
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
...
@@ -291,23 +368,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
...
@@ -291,23 +368,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
bias
:
Optional
[
bool
]
=
True
,
bias
:
Optional
[
bool
]
=
True
,
theta_cutoff
:
Optional
[
float
]
=
None
,
theta_cutoff
:
Optional
[
float
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
(
in_channels
,
out_channels
,
kernel_shape
,
groups
,
bias
)
self
.
nlat_in
,
self
.
nlon_in
=
in_shape
self
.
nlat_in
,
self
.
nlon_in
=
in_shape
self
.
nlat_out
,
self
.
nlon_out
=
out_shape
self
.
nlat_out
,
self
.
nlon_out
=
out_shape
if
isinstance
(
kernel_shape
,
int
):
kernel_shape
=
[
kernel_shape
]
if
len
(
kernel_shape
)
==
1
:
self
.
kernel_size
=
kernel_shape
[
0
]
elif
len
(
kernel_shape
)
==
2
:
self
.
kernel_size
=
(
kernel_shape
[
0
]
-
1
)
*
kernel_shape
[
1
]
+
1
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
# bandlimit
# bandlimit
if
theta_cutoff
is
None
:
if
theta_cutoff
is
None
:
theta_cutoff
=
(
kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
self
.
nlat_in
-
1
)
theta_cutoff
=
(
self
.
kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
self
.
nlat_in
-
1
)
if
theta_cutoff
<=
0.0
:
if
theta_cutoff
<=
0.0
:
raise
ValueError
(
"Error, theta_cutoff has to be positive."
)
raise
ValueError
(
"Error, theta_cutoff has to be positive."
)
...
@@ -318,32 +386,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
...
@@ -318,32 +386,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
self
.
register_buffer
(
"quad_weights"
,
quad_weights
,
persistent
=
False
)
self
.
register_buffer
(
"quad_weights"
,
quad_weights
,
persistent
=
False
)
# switch in_shape and out_shape since we want transpose conv
# switch in_shape and out_shape since we want transpose conv
idx
,
vals
=
_precompute_convolution_tensor
(
idx
,
vals
=
_precompute_convolution_tensor_s2
(
out_shape
,
in_shape
,
self
.
kernel_shape
,
grid_in
=
grid_out
,
grid_out
=
grid_in
,
theta_cutoff
=
theta_cutoff
)
out_shape
,
in_shape
,
kernel_shape
,
grid_in
=
grid_out
,
grid_out
=
grid_in
,
theta_cutoff
=
theta_cutoff
)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
# ).coalesce()
self
.
register_buffer
(
"psi_idx"
,
idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_idx"
,
idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_vals"
,
vals
,
persistent
=
False
)
self
.
register_buffer
(
"psi_vals"
,
vals
,
persistent
=
False
)
# self.register_buffer("psi", psi, persistent=False)
# groups
self
.
groups
=
groups
# weight tensor
def
get_psi
(
self
):
if
in_channels
%
self
.
groups
!=
0
:
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_in
,
self
.
nlat_out
*
self
.
nlon_out
)).
coalesce
()
raise
ValueError
(
"Error, the number of input channels has to be an integer multiple of the group size"
)
return
psi
if
out_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
self
.
groupsize
=
in_channels
//
self
.
groups
scale
=
math
.
sqrt
(
1.0
/
self
.
groupsize
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
self
.
groupsize
,
self
.
kernel_size
))
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
else
:
self
.
bias
=
None
def
forward
(
self
,
x
:
torch
.
Tensor
,
use_triton_kernel
:
bool
=
True
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
,
use_triton_kernel
:
bool
=
True
)
->
torch
.
Tensor
:
# extract shape
# extract shape
...
@@ -357,7 +407,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
...
@@ -357,7 +407,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
# pre-multiply x with the quadrature weights
# pre-multiply x with the quadrature weights
x
=
self
.
quad_weights
*
x
x
=
self
.
quad_weights
*
x
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_in
,
self
.
nlat_out
*
self
.
nlon_out
)).
coalesce
()
psi
=
self
.
get_psi
()
if
x
.
is_cuda
and
use_triton_kernel
:
if
x
.
is_cuda
and
use_triton_kernel
:
out
=
_disco_s2_transpose_contraction_triton
(
x
,
psi
,
self
.
nlon_out
)
out
=
_disco_s2_transpose_contraction_triton
(
x
,
psi
,
self
.
nlon_out
)
...
@@ -368,3 +418,4 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
...
@@ -368,3 +418,4 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
out
=
out
+
self
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
out
=
out
+
self
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
return
out
return
out
torch_harmonics/quadrature.py
View file @
54502a17
...
@@ -31,26 +31,53 @@
...
@@ -31,26 +31,53 @@
import
numpy
as
np
import
numpy
as
np
def
_precompute_grid
(
n
,
grid
=
"equidistant"
,
a
=
0.0
,
b
=
1.0
,
periodic
=
False
):
if
(
grid
!=
"equidistant"
)
and
periodic
:
raise
ValueError
(
f
"Periodic grid is only supported on equidistant grids."
)
# compute coordinates
if
grid
==
"equidistant"
:
xlg
,
wlg
=
trapezoidal_weights
(
n
,
a
=
a
,
b
=
b
,
periodic
=
periodic
)
elif
grid
==
"legendre-gauss"
:
xlg
,
wlg
=
legendre_gauss_weights
(
n
,
a
=
a
,
b
=
b
)
elif
grid
==
"lobatto"
:
xlg
,
wlg
=
lobatto_weights
(
n
,
a
=
a
,
b
=
b
)
elif
grid
==
"equiangular"
:
xlg
,
wlg
=
clenshaw_curtiss_weights
(
n
,
a
=
a
,
b
=
b
)
else
:
raise
ValueError
(
f
"Unknown grid type
{
grid
}
"
)
return
xlg
,
wlg
def
_precompute_latitudes
(
nlat
,
grid
=
"equiangular"
):
def
_precompute_latitudes
(
nlat
,
grid
=
"equiangular"
):
r
"""
r
"""
Convenience routine to precompute latitudes
Convenience routine to precompute latitudes
"""
"""
# compute coordinates
# compute coordinates
if
grid
==
"legendre-gauss"
:
xlg
,
wlg
=
_precompute_grid
(
nlat
,
grid
=
grid
,
a
=-
1.0
,
b
=
1.0
,
periodic
=
False
)
xlg
,
wlg
=
legendre_gauss_weights
(
nlat
)
elif
grid
==
"lobatto"
:
xlg
,
wlg
=
lobatto_weights
(
nlat
)
elif
grid
==
"equiangular"
:
xlg
,
wlg
=
clenshaw_curtiss_weights
(
nlat
)
else
:
raise
ValueError
(
"Unknown grid"
)
lats
=
np
.
flip
(
np
.
arccos
(
xlg
)).
copy
()
lats
=
np
.
flip
(
np
.
arccos
(
xlg
)).
copy
()
wlg
=
np
.
flip
(
wlg
).
copy
()
wlg
=
np
.
flip
(
wlg
).
copy
()
return
lats
,
wlg
return
lats
,
wlg
def
trapezoidal_weights
(
n
,
a
=-
1.0
,
b
=
1.0
,
periodic
=
False
):
r
"""
Helper routine which returns equidistant nodes with trapezoidal weights
on the interval [a, b]
"""
xlg
=
np
.
linspace
(
a
,
b
,
n
)
wlg
=
(
b
-
a
)
/
(
n
-
1
)
*
np
.
ones
(
n
)
if
not
periodic
:
wlg
[
0
]
*=
0.5
wlg
[
-
1
]
*=
0.5
return
xlg
,
wlg
def
legendre_gauss_weights
(
n
,
a
=-
1.0
,
b
=
1.0
):
def
legendre_gauss_weights
(
n
,
a
=-
1.0
,
b
=
1.0
):
r
"""
r
"""
Helper routine which returns the Legendre-Gauss nodes and weights
Helper routine which returns the Legendre-Gauss nodes and weights
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment