Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
54502a17
Unverified
Commit
54502a17
authored
Jan 30, 2024
by
Boris Bonev
Committed by
GitHub
Jan 30, 2024
Browse files
Bbonev/disco refactor (#29)
* Cleaned up DISCO convolutions
parent
c971d458
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
196 additions
and
117 deletions
+196
-117
Changelog.md
Changelog.md
+5
-4
torch_harmonics/__init__.py
torch_harmonics/__init__.py
+1
-1
torch_harmonics/convolution.py
torch_harmonics/convolution.py
+155
-104
torch_harmonics/quadrature.py
torch_harmonics/quadrature.py
+35
-8
No files found.
Changelog.md
View file @
54502a17
...
...
@@ -4,10 +4,11 @@
### v0.6.5
*
Discrrete-continuous (DISCO) convolutions on the sphere
*
Isotropic and anisotropic DISCO convolutions
*
Accelerated DISCO convolutions on GPU via Triton implementation
*
Unittests for DISCO convolutions
*
Discrete-continuous (DISCO) convolutions on the sphere and in two dimensions
*
DISCO supports isotropic and anisotropic kernel functions parameterized as hat functions
*
Supports regular and transpose convolutions
*
Accelerated spherical DISCO convolutions on GPU via Triton implementation
*
Unittests for DISCO convolutions in
`tests/test_convolution.py`
### v0.6.4
...
...
torch_harmonics/__init__.py
View file @
54502a17
...
...
@@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
__version__
=
'0.6.
4
'
__version__
=
'0.6.
5
'
from
.sht
import
RealSHT
,
InverseRealSHT
,
RealVectorSHT
,
InverseRealVectorSHT
from
.convolution
import
DiscreteContinuousConvS2
,
DiscreteContinuousConvTransposeS2
...
...
torch_harmonics/convolution.py
View file @
54502a17
...
...
@@ -29,6 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import
abc
from
typing
import
List
,
Tuple
,
Union
,
Optional
import
math
...
...
@@ -38,7 +39,7 @@ import torch.nn as nn
from
functools
import
partial
from
torch_harmonics.quadrature
import
_precompute_latitudes
from
torch_harmonics.quadrature
import
_precompute_grid
,
_precompute_latitudes
from
torch_harmonics._disco_convolution
import
(
_disco_s2_contraction_torch
,
_disco_s2_transpose_contraction_torch
,
...
...
@@ -47,50 +48,67 @@ from torch_harmonics._disco_convolution import (
)
def
_compute_support_vals_isotropic
(
theta
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
n
theta
:
int
,
theta
_cutoff
:
float
):
def
_compute_support_vals_isotropic
(
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
n
r
:
int
,
r
_cutoff
:
float
,
norm
:
str
=
"s2"
):
"""
Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
"""
# compute the support
dtheta
=
(
theta_cutoff
-
0.0
)
/
ntheta
ikernel
=
torch
.
arange
(
ntheta
).
reshape
(
-
1
,
1
,
1
)
itheta
=
ikernel
*
dtheta
norm_factor
=
2
*
math
.
pi
*
(
1
-
math
.
cos
(
theta_cutoff
-
dtheta
)
+
math
.
cos
(
theta_cutoff
-
dtheta
)
+
(
math
.
sin
(
theta_cutoff
-
dtheta
)
-
math
.
sin
(
theta_cutoff
))
/
dtheta
)
dr
=
(
r_cutoff
-
0.0
)
/
nr
ikernel
=
torch
.
arange
(
nr
).
reshape
(
-
1
,
1
,
1
)
ir
=
ikernel
*
dr
if
norm
==
"none"
:
norm_factor
=
1.0
elif
norm
==
"2d"
:
norm_factor
=
math
.
pi
*
(
r_cutoff
*
nr
/
(
nr
+
1
))
**
2
+
math
.
pi
*
r_cutoff
**
2
*
(
2
*
nr
/
(
nr
+
1
)
+
1
)
/
(
nr
+
1
)
/
3
elif
norm
==
"s2"
:
norm_factor
=
2
*
math
.
pi
*
(
1
-
math
.
cos
(
r_cutoff
-
dr
)
+
math
.
cos
(
r_cutoff
-
dr
)
+
(
math
.
sin
(
r_cutoff
-
dr
)
-
math
.
sin
(
r_cutoff
))
/
dr
)
else
:
raise
ValueError
(
f
"Unknown normalization mode
{
norm
}
."
)
# find the indices where the rotated position falls into the support of the kernel
iidx
=
torch
.
argwhere
(((
theta
-
itheta
).
abs
()
<=
dtheta
)
&
(
theta
<=
theta
_cutoff
))
vals
=
(
1
-
(
theta
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
i
theta
[
iidx
[:,
0
],
0
,
0
]).
abs
()
/
d
theta
)
/
norm_factor
iidx
=
torch
.
argwhere
(((
r
-
ir
).
abs
()
<=
dr
)
&
(
r
<=
r
_cutoff
))
vals
=
(
1
-
(
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
i
r
[
iidx
[:,
0
],
0
,
0
]).
abs
()
/
d
r
)
/
norm_factor
return
iidx
,
vals
def
_compute_support_vals_anisotropic
(
theta
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
ntheta
:
int
,
nphi
:
int
,
theta_cutoff
:
float
):
def
_compute_support_vals_anisotropic
(
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
nr
:
int
,
nphi
:
int
,
r_cutoff
:
float
,
norm
:
str
=
"s2"
):
"""
Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
"""
# compute the support
d
theta
=
(
theta
_cutoff
-
0.0
)
/
n
theta
d
r
=
(
r
_cutoff
-
0.0
)
/
n
r
dphi
=
2.0
*
math
.
pi
/
nphi
kernel_size
=
(
n
theta
-
1
)
*
nphi
+
1
kernel_size
=
(
n
r
-
1
)
*
nphi
+
1
ikernel
=
torch
.
arange
(
kernel_size
).
reshape
(
-
1
,
1
,
1
)
i
theta
=
((
ikernel
-
1
)
//
nphi
+
1
)
*
d
theta
i
r
=
((
ikernel
-
1
)
//
nphi
+
1
)
*
d
r
iphi
=
((
ikernel
-
1
)
%
nphi
)
*
dphi
norm_factor
=
2
*
math
.
pi
*
(
1
-
math
.
cos
(
theta_cutoff
-
dtheta
)
+
math
.
cos
(
theta_cutoff
-
dtheta
)
+
(
math
.
sin
(
theta_cutoff
-
dtheta
)
-
math
.
sin
(
theta_cutoff
))
/
dtheta
)
if
norm
==
"none"
:
norm_factor
=
1.0
elif
norm
==
"2d"
:
norm_factor
=
math
.
pi
*
(
r_cutoff
*
nr
/
(
nr
+
1
))
**
2
+
math
.
pi
*
r_cutoff
**
2
*
(
2
*
nr
/
(
nr
+
1
)
+
1
)
/
(
nr
+
1
)
/
3
elif
norm
==
"s2"
:
norm_factor
=
2
*
math
.
pi
*
(
1
-
math
.
cos
(
r_cutoff
-
dr
)
+
math
.
cos
(
r_cutoff
-
dr
)
+
(
math
.
sin
(
r_cutoff
-
dr
)
-
math
.
sin
(
r_cutoff
))
/
dr
)
else
:
raise
ValueError
(
f
"Unknown normalization mode
{
norm
}
."
)
# find the indices where the rotated position falls into the support of the kernel
cond_theta
=
((
theta
-
itheta
).
abs
()
<=
dtheta
)
&
(
theta
<=
theta_cutoff
)
cond_phi
=
(
ikernel
==
0
)
|
((
phi
-
iphi
).
abs
()
<=
dphi
)
|
((
2
*
math
.
pi
-
(
phi
-
iphi
).
abs
())
<=
dphi
)
iidx
=
torch
.
argwhere
(
cond_theta
&
cond_phi
)
vals
=
(
1
-
(
theta
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
itheta
[
iidx
[:,
0
],
0
,
0
]).
abs
()
/
dtheta
)
/
norm_factor
vals
*=
torch
.
where
(
iidx
[:,
0
]
>
0
,
(
1
-
torch
.
minimum
((
phi
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
iphi
[
iidx
[:,
0
],
0
,
0
]).
abs
(),
(
2
*
math
.
pi
-
(
phi
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
iphi
[
iidx
[:,
0
],
0
,
0
]).
abs
())
)
/
dphi
),
1.0
)
cond_r
=
((
r
-
ir
).
abs
()
<=
dr
)
&
(
r
<=
r_cutoff
)
cond_phi
=
(
ikernel
==
0
)
|
((
phi
-
iphi
).
abs
()
<=
dphi
)
|
((
2
*
math
.
pi
-
(
phi
-
iphi
).
abs
())
<=
dphi
)
iidx
=
torch
.
argwhere
(
cond_r
&
cond_phi
)
vals
=
(
1
-
(
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
ir
[
iidx
[:,
0
],
0
,
0
]).
abs
()
/
dr
)
/
norm_factor
vals
*=
torch
.
where
(
iidx
[:,
0
]
>
0
,
(
1
-
torch
.
minimum
((
phi
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
iphi
[
iidx
[:,
0
],
0
,
0
]).
abs
(),
(
2
*
math
.
pi
-
(
phi
[
iidx
[:,
1
],
iidx
[:,
2
]]
-
iphi
[
iidx
[:,
0
],
0
,
0
]).
abs
()))
/
dphi
),
1.0
,
)
return
iidx
,
vals
def
_precompute_convolution_tensor
(
in_shape
,
out_shape
,
kernel_shape
,
grid_in
=
"equiangular"
,
grid_out
=
"equiangular"
,
theta_cutoff
=
0.01
*
math
.
pi
):
def
_precompute_convolution_tensor_s2
(
in_shape
,
out_shape
,
kernel_shape
,
grid_in
=
"equiangular"
,
grid_out
=
"equiangular"
,
theta_cutoff
=
0.01
*
math
.
pi
):
"""
Precomputes the rotated filters at positions $R^{-1}_j \omega_i = R^{-1}_j R_i
\n
u = Y(-
\t
heta_j)Z(\phi_i - \phi_j)Y(
\t
heta_j)
\n
u$.
Assumes a tensorized grid on the sphere with an equidistant sampling in longitude as described in Ocampo et al.
...
...
@@ -111,9 +129,9 @@ def _precompute_convolution_tensor(
assert
len
(
out_shape
)
==
2
if
len
(
kernel_shape
)
==
1
:
kernel_handle
=
partial
(
_compute_support_vals_isotropic
,
n
theta
=
kernel_shape
[
0
],
theta
_cutoff
=
theta_cutoff
)
kernel_handle
=
partial
(
_compute_support_vals_isotropic
,
n
r
=
kernel_shape
[
0
],
r
_cutoff
=
theta_cutoff
,
norm
=
"s2"
)
elif
len
(
kernel_shape
)
==
2
:
kernel_handle
=
partial
(
_compute_support_vals_anisotropic
,
n
theta
=
kernel_shape
[
0
],
nphi
=
kernel_shape
[
1
],
theta
_cutoff
=
theta_cutoff
)
kernel_handle
=
partial
(
_compute_support_vals_anisotropic
,
n
r
=
kernel_shape
[
0
],
nphi
=
kernel_shape
[
1
],
r
_cutoff
=
theta_cutoff
,
norm
=
"s2"
)
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
...
...
@@ -131,24 +149,24 @@ def _precompute_convolution_tensor(
# compute the phi differences
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in
=
torch
.
linspace
(
0
,
2
*
math
.
pi
,
nlon_in
+
1
)[:
-
1
]
lons_in
=
torch
.
linspace
(
0
,
2
*
math
.
pi
,
nlon_in
+
1
)[:
-
1
]
for
t
in
range
(
nlat_out
):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
alpha
=
-
lats_out
[
t
]
alpha
=
-
lats_out
[
t
]
beta
=
lons_in
gamma
=
lats_in
.
reshape
(
-
1
,
1
)
# compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign
z
=
-
torch
.
cos
(
beta
)
*
torch
.
sin
(
alpha
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
alpha
)
*
torch
.
cos
(
gamma
)
z
=
-
torch
.
cos
(
beta
)
*
torch
.
sin
(
alpha
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
alpha
)
*
torch
.
cos
(
gamma
)
x
=
torch
.
cos
(
alpha
)
*
torch
.
cos
(
beta
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
gamma
)
*
torch
.
sin
(
alpha
)
y
=
torch
.
sin
(
beta
)
*
torch
.
sin
(
gamma
)
# normalization is emportant to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
norm
=
torch
.
sqrt
(
x
*
x
+
y
*
y
+
z
*
z
)
norm
=
torch
.
sqrt
(
x
*
x
+
y
*
y
+
z
*
z
)
x
=
x
/
norm
y
=
y
/
norm
z
=
z
/
norm
...
...
@@ -170,9 +188,96 @@ def _precompute_convolution_tensor(
return
out_idx
,
out_vals
# TODO:
# - derive conv and conv transpose from single module
class
DiscreteContinuousConvS2
(
nn
.
Module
):
def
_precompute_convolution_tensor_2d
(
grid_in
,
grid_out
,
kernel_shape
,
radius_cutoff
=
0.01
,
periodic
=
False
):
"""
Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i
\n
u$. Similar to the S2 routine,
only that it assumes a non-periodic subset of the euclidean plane
"""
# check that input arrays are valid point clouds in 2D
assert
len
(
grid_in
)
==
2
assert
len
(
grid_out
)
==
2
assert
grid_in
.
shape
[
0
]
==
2
assert
grid_out
.
shape
[
0
]
==
2
n_in
=
grid_in
.
shape
[
-
1
]
n_out
=
grid_out
.
shape
[
-
1
]
if
len
(
kernel_shape
)
==
1
:
kernel_handle
=
partial
(
_compute_support_vals_isotropic
,
nr
=
kernel_shape
[
0
],
r_cutoff
=
radius_cutoff
,
norm
=
"2d"
)
elif
len
(
kernel_shape
)
==
2
:
kernel_handle
=
partial
(
_compute_support_vals_anisotropic
,
nr
=
kernel_shape
[
0
],
nphi
=
kernel_shape
[
1
],
r_cutoff
=
radius_cutoff
,
norm
=
"2d"
)
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
grid_in
=
grid_in
.
reshape
(
2
,
1
,
n_in
)
grid_out
=
grid_out
.
reshape
(
2
,
n_out
,
1
)
diffs
=
grid_in
-
grid_out
if
periodic
:
periodic_diffs
=
torch
.
where
(
diffs
>
0.0
,
diffs
-
1
,
diffs
+
1
)
diffs
=
torch
.
where
(
diffs
.
abs
()
<
periodic_diffs
.
abs
(),
diffs
,
periodic_diffs
)
r
=
torch
.
sqrt
(
diffs
[
0
]
**
2
+
diffs
[
1
]
**
2
)
phi
=
torch
.
arctan2
(
diffs
[
1
],
diffs
[
0
])
+
torch
.
pi
idx
,
vals
=
kernel_handle
(
r
,
phi
)
idx
=
idx
.
permute
(
1
,
0
)
return
idx
,
vals
class
DiscreteContinuousConv
(
nn
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""
Abstract base class for DISCO convolutions
"""
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
kernel_shape
:
Union
[
int
,
List
[
int
]],
groups
:
Optional
[
int
]
=
1
,
bias
:
Optional
[
bool
]
=
True
,
):
super
().
__init__
()
if
isinstance
(
kernel_shape
,
int
):
self
.
kernel_shape
=
[
kernel_shape
]
else
:
self
.
kernel_shape
=
kernel_shape
if
len
(
self
.
kernel_shape
)
==
1
:
self
.
kernel_size
=
self
.
kernel_shape
[
0
]
elif
len
(
self
.
kernel_shape
)
==
2
:
self
.
kernel_size
=
(
self
.
kernel_shape
[
0
]
-
1
)
*
self
.
kernel_shape
[
1
]
+
1
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
# groups
self
.
groups
=
groups
# weight tensor
if
in_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of input channels has to be an integer multiple of the group size"
)
if
out_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
self
.
groupsize
=
in_channels
//
self
.
groups
scale
=
math
.
sqrt
(
1.0
/
self
.
groupsize
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
self
.
groupsize
,
self
.
kernel_size
))
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
else
:
self
.
bias
=
None
@
abc
.
abstractmethod
def
forward
(
self
,
x
:
torch
.
Tensor
):
raise
NotImplementedError
class
DiscreteContinuousConvS2
(
DiscreteContinuousConv
):
"""
Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
...
...
@@ -192,24 +297,14 @@ class DiscreteContinuousConvS2(nn.Module):
bias
:
Optional
[
bool
]
=
True
,
theta_cutoff
:
Optional
[
float
]
=
None
,
):
super
().
__init__
()
super
().
__init__
(
in_channels
,
out_channels
,
kernel_shape
,
groups
,
bias
)
self
.
nlat_in
,
self
.
nlon_in
=
in_shape
self
.
nlat_out
,
self
.
nlon_out
=
out_shape
if
isinstance
(
kernel_shape
,
int
):
kernel_shape
=
[
kernel_shape
]
if
len
(
kernel_shape
)
==
1
:
self
.
kernel_size
=
kernel_shape
[
0
]
elif
len
(
kernel_shape
)
==
2
:
self
.
kernel_size
=
(
kernel_shape
[
0
]
-
1
)
*
kernel_shape
[
1
]
+
1
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
# compute theta cutoff based on the bandlimit of the input field
if
theta_cutoff
is
None
:
theta_cutoff
=
(
kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
self
.
nlat_in
-
1
)
theta_cutoff
=
(
self
.
kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
self
.
nlat_in
-
1
)
if
theta_cutoff
<=
0.0
:
raise
ValueError
(
"Error, theta_cutoff has to be positive."
)
...
...
@@ -219,38 +314,20 @@ class DiscreteContinuousConvS2(nn.Module):
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
wgl
).
float
().
reshape
(
-
1
,
1
)
/
self
.
nlon_in
self
.
register_buffer
(
"quad_weights"
,
quad_weights
,
persistent
=
False
)
idx
,
vals
=
_precompute_convolution_tensor
(
in_shape
,
out_shape
,
kernel_shape
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
theta_cutoff
=
theta_cutoff
)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
# ).coalesce()
idx
,
vals
=
_precompute_convolution_tensor_s2
(
in_shape
,
out_shape
,
self
.
kernel_shape
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
theta_cutoff
=
theta_cutoff
)
self
.
register_buffer
(
"psi_idx"
,
idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_vals"
,
vals
,
persistent
=
False
)
# self.register_buffer("psi", psi, persistent=False)
# groups
self
.
groups
=
groups
# weight tensor
if
in_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of input channels has to be an integer multiple of the group size"
)
if
out_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
self
.
groupsize
=
in_channels
//
self
.
groups
scale
=
math
.
sqrt
(
1.0
/
self
.
groupsize
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
self
.
groupsize
,
self
.
kernel_size
))
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
else
:
self
.
bias
=
None
def
get_psi
(
self
):
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_out
,
self
.
nlat_in
*
self
.
nlon_in
)).
coalesce
()
return
psi
def
forward
(
self
,
x
:
torch
.
Tensor
,
use_triton_kernel
:
bool
=
True
)
->
torch
.
Tensor
:
# pre-multiply x with the quadrature weights
x
=
self
.
quad_weights
*
x
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_out
,
self
.
nlat_in
*
self
.
nlon_in
)).
coalesce
()
psi
=
self
.
get_psi
()
if
x
.
is_cuda
and
use_triton_kernel
:
x
=
_disco_s2_contraction_triton
(
x
,
psi
,
self
.
nlon_out
)
...
...
@@ -271,7 +348,7 @@ class DiscreteContinuousConvS2(nn.Module):
return
out
class
DiscreteContinuousConvTransposeS2
(
nn
.
Module
):
class
DiscreteContinuousConvTransposeS2
(
DiscreteContinuousConv
):
"""
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
...
...
@@ -291,23 +368,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
bias
:
Optional
[
bool
]
=
True
,
theta_cutoff
:
Optional
[
float
]
=
None
,
):
super
().
__init__
()
super
().
__init__
(
in_channels
,
out_channels
,
kernel_shape
,
groups
,
bias
)
self
.
nlat_in
,
self
.
nlon_in
=
in_shape
self
.
nlat_out
,
self
.
nlon_out
=
out_shape
if
isinstance
(
kernel_shape
,
int
):
kernel_shape
=
[
kernel_shape
]
if
len
(
kernel_shape
)
==
1
:
self
.
kernel_size
=
kernel_shape
[
0
]
elif
len
(
kernel_shape
)
==
2
:
self
.
kernel_size
=
(
kernel_shape
[
0
]
-
1
)
*
kernel_shape
[
1
]
+
1
else
:
raise
ValueError
(
"kernel_shape should be either one- or two-dimensional."
)
# bandlimit
if
theta_cutoff
is
None
:
theta_cutoff
=
(
kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
self
.
nlat_in
-
1
)
theta_cutoff
=
(
self
.
kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
self
.
nlat_in
-
1
)
if
theta_cutoff
<=
0.0
:
raise
ValueError
(
"Error, theta_cutoff has to be positive."
)
...
...
@@ -318,32 +386,14 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
self
.
register_buffer
(
"quad_weights"
,
quad_weights
,
persistent
=
False
)
# switch in_shape and out_shape since we want transpose conv
idx
,
vals
=
_precompute_convolution_tensor
(
out_shape
,
in_shape
,
kernel_shape
,
grid_in
=
grid_out
,
grid_out
=
grid_in
,
theta_cutoff
=
theta_cutoff
)
# psi = torch.sparse_coo_tensor(
# idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
# ).coalesce()
idx
,
vals
=
_precompute_convolution_tensor_s2
(
out_shape
,
in_shape
,
self
.
kernel_shape
,
grid_in
=
grid_out
,
grid_out
=
grid_in
,
theta_cutoff
=
theta_cutoff
)
self
.
register_buffer
(
"psi_idx"
,
idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_vals"
,
vals
,
persistent
=
False
)
# self.register_buffer("psi", psi, persistent=False)
# groups
self
.
groups
=
groups
# weight tensor
if
in_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of input channels has to be an integer multiple of the group size"
)
if
out_channels
%
self
.
groups
!=
0
:
raise
ValueError
(
"Error, the number of output channels has to be an integer multiple of the group size"
)
self
.
groupsize
=
in_channels
//
self
.
groups
scale
=
math
.
sqrt
(
1.0
/
self
.
groupsize
)
self
.
weight
=
nn
.
Parameter
(
scale
*
torch
.
randn
(
out_channels
,
self
.
groupsize
,
self
.
kernel_size
))
if
bias
:
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
out_channels
))
else
:
self
.
bias
=
None
def
get_psi
(
self
):
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_in
,
self
.
nlat_out
*
self
.
nlon_out
)).
coalesce
()
return
psi
def
forward
(
self
,
x
:
torch
.
Tensor
,
use_triton_kernel
:
bool
=
True
)
->
torch
.
Tensor
:
# extract shape
...
...
@@ -357,7 +407,7 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
# pre-multiply x with the quadrature weights
x
=
self
.
quad_weights
*
x
psi
=
torch
.
sparse_coo_tensor
(
self
.
psi_idx
,
self
.
psi_vals
,
size
=
(
self
.
kernel_size
,
self
.
nlat_in
,
self
.
nlat_out
*
self
.
nlon_out
)).
coalesce
()
psi
=
self
.
get_psi
()
if
x
.
is_cuda
and
use_triton_kernel
:
out
=
_disco_s2_transpose_contraction_triton
(
x
,
psi
,
self
.
nlon_out
)
...
...
@@ -368,3 +418,4 @@ class DiscreteContinuousConvTransposeS2(nn.Module):
out
=
out
+
self
.
bias
.
reshape
(
1
,
-
1
,
1
,
1
)
return
out
torch_harmonics/quadrature.py
View file @
54502a17
...
...
@@ -31,26 +31,53 @@
import
numpy
as
np
def
_precompute_grid
(
n
,
grid
=
"equidistant"
,
a
=
0.0
,
b
=
1.0
,
periodic
=
False
):
if
(
grid
!=
"equidistant"
)
and
periodic
:
raise
ValueError
(
f
"Periodic grid is only supported on equidistant grids."
)
# compute coordinates
if
grid
==
"equidistant"
:
xlg
,
wlg
=
trapezoidal_weights
(
n
,
a
=
a
,
b
=
b
,
periodic
=
periodic
)
elif
grid
==
"legendre-gauss"
:
xlg
,
wlg
=
legendre_gauss_weights
(
n
,
a
=
a
,
b
=
b
)
elif
grid
==
"lobatto"
:
xlg
,
wlg
=
lobatto_weights
(
n
,
a
=
a
,
b
=
b
)
elif
grid
==
"equiangular"
:
xlg
,
wlg
=
clenshaw_curtiss_weights
(
n
,
a
=
a
,
b
=
b
)
else
:
raise
ValueError
(
f
"Unknown grid type
{
grid
}
"
)
return
xlg
,
wlg
def
_precompute_latitudes
(
nlat
,
grid
=
"equiangular"
):
r
"""
Convenience routine to precompute latitudes
"""
# compute coordinates
if
grid
==
"legendre-gauss"
:
xlg
,
wlg
=
legendre_gauss_weights
(
nlat
)
elif
grid
==
"lobatto"
:
xlg
,
wlg
=
lobatto_weights
(
nlat
)
elif
grid
==
"equiangular"
:
xlg
,
wlg
=
clenshaw_curtiss_weights
(
nlat
)
else
:
raise
ValueError
(
"Unknown grid"
)
xlg
,
wlg
=
_precompute_grid
(
nlat
,
grid
=
grid
,
a
=-
1.0
,
b
=
1.0
,
periodic
=
False
)
lats
=
np
.
flip
(
np
.
arccos
(
xlg
)).
copy
()
wlg
=
np
.
flip
(
wlg
).
copy
()
return
lats
,
wlg
def
trapezoidal_weights
(
n
,
a
=-
1.0
,
b
=
1.0
,
periodic
=
False
):
r
"""
Helper routine which returns equidistant nodes with trapezoidal weights
on the interval [a, b]
"""
xlg
=
np
.
linspace
(
a
,
b
,
n
)
wlg
=
(
b
-
a
)
/
(
n
-
1
)
*
np
.
ones
(
n
)
if
not
periodic
:
wlg
[
0
]
*=
0.5
wlg
[
-
1
]
*=
0.5
return
xlg
,
wlg
def
legendre_gauss_weights
(
n
,
a
=-
1.0
,
b
=
1.0
):
r
"""
Helper routine which returns the Legendre-Gauss nodes and weights
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment