OpenDAS / torch-harmonics

Commit ea8d1a2e
Authored Jul 21, 2025 by Thorsten Kurth
Parent: c877cda6

    adding device args to more functions

Changes: 2 changed files, with 64 additions and 4 deletions (+64 −4)

    tests/test_convolution.py          +61  −1
    torch_harmonics/filter_basis.py     +3  −3
tests/test_convolution.py
```diff
@@ -127,7 +127,7 @@ def _precompute_convolution_tensor_dense(
     quad_weights = win.reshape(-1, 1) / nlon_in / 2.0

     # array for accumulating non-zero indices
-    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64)
+    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in, dtype=torch.float64, device=lons_in.device)

     for t in range(nlat_out):
         for p in range(nlon_out):
```
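The only change in this hunk allocates the dense accumulator directly on the device of the input longitudes instead of on the default (CPU) device. A minimal sketch of that pattern, with illustrative names and shapes rather than the library's actual signature:

```python
import torch

def make_accumulator(lons_in: torch.Tensor, kernel_size: int, shape: tuple) -> torch.Tensor:
    # Allocating on lons_in.device up front avoids creating a CPU tensor
    # and paying for an extra host-to-device copy via .to(...) later.
    return torch.zeros(kernel_size, *shape, dtype=torch.float64, device=lons_in.device)

# Hypothetical usage: if the quadrature coordinates already live on the GPU,
# the accumulator is created there as well.
device = "cuda" if torch.cuda.is_available() else "cpu"
lons_in = torch.linspace(0, 2 * torch.pi, 32, device=device)
out = make_accumulator(lons_in, kernel_size=3, shape=(16, 32, 16, 32))
print(out.device)
```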
```diff
@@ -315,6 +315,66 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
         self.assertTrue(torch.allclose(x_grad, x_ref_grad, rtol=tol, atol=tol))
         self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))

+    @parameterized.expand(
+        [
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (8, 16), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", False, 1e-4, False],
+            [8, 4, 2, (16, 32), (16, 32), (3), "piecewise linear", "mean", "equiangular", "equiangular", True, 1e-4, False],
+            [8, 4, 2, (8, 16), (16, 32), (5), "piecewise linear", "mean", "legendre-gauss", "legendre-gauss", True, 1e-4, False],
+        ]
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
+    def test_device_instantiation(self, batch_size, in_channels, out_channels, in_shape, out_shape, kernel_shape, basis_type, basis_norm_mode, grid_in, grid_out, transpose, tol, verbose):
+
+        nlat_in, nlon_in = in_shape
+        nlat_out, nlon_out = out_shape
+
+        if isinstance(kernel_shape, int):
+            theta_cutoff = (kernel_shape + 1) * torch.pi / float(nlat_in - 1)
+        else:
+            theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
+
+        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
+
+        conv_host = Conv(
+            in_channels,
+            out_channels,
+            in_shape,
+            out_shape,
+            kernel_shape,
+            basis_type=basis_type,
+            basis_norm_mode=basis_norm_mode,
+            groups=1,
+            grid_in=grid_in,
+            grid_out=grid_out,
+            bias=False,
+            theta_cutoff=theta_cutoff,
+        )
+
+        torch.set_default_device(self.device)
+        #with(self.device):
+        conv_device = Conv(
+            in_channels,
+            out_channels,
+            in_shape,
+            out_shape,
+            kernel_shape,
+            basis_type=basis_type,
+            basis_norm_mode=basis_norm_mode,
+            groups=1,
+            grid_in=grid_in,
+            grid_out=grid_out,
+            bias=False,
+            theta_cutoff=theta_cutoff,
+        )
+
+        print(conv_host.psi_col_idx.device, conv_device.psi_col_idx.device)
+
+        self.assertTrue(torch.allclose(conv_host.psi_col_idx.cpu(), conv_device.psi_col_idx.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_row_idx.cpu(), conv_device.psi_row_idx.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_roff_idx.cpu(), conv_device.psi_roff_idx.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_vals.cpu(), conv_device.psi_vals.cpu()))
+        self.assertTrue(torch.allclose(conv_host.psi_idx.cpu(), conv_device.psi_idx.cpu()))

 if __name__ == "__main__":
     unittest.main()
```
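The new parameterized test builds the convolution twice: once on the host and once after switching the global default device, then checks that the precomputed sparse index and value buffers agree. A rough sketch of the instantiation pattern the test relies on, using a stand-in module rather than DiscreteContinuousConvS2 and assuming a recent PyTorch release that provides torch.set_default_device and the torch.device context manager:

```python
import torch

class TinyModule(torch.nn.Module):
    def __init__(self, n: int):
        super().__init__()
        # Factory calls inside __init__ honor the current default device,
        # so this buffer ends up wherever the default device points.
        self.register_buffer("idx", torch.arange(n))

# Host instantiation: the buffer lives on the CPU.
host = TinyModule(8)

if torch.cuda.is_available():
    # Pattern used in the test: change the global default device.
    torch.set_default_device("cuda")
    dev = TinyModule(8)
    torch.set_default_device("cpu")  # restore the default for later code

    # Alternative hinted at by the commented-out line: a device context
    # manager scopes the default device locally instead of globally.
    with torch.device("cuda"):
        dev2 = TinyModule(8)

    print(host.idx.device, dev.idx.device, dev2.idx.device)
```

Restoring the default device afterwards (or scoping it with the context manager) keeps later tests from silently allocating on the GPU.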
torch_harmonics/filter_basis.py
```diff
@@ -254,7 +254,7 @@ class MorletFilterBasis(FilterBasis):
         mkernel = ikernel // self.kernel_shape[1]

         # get relevant indices
-        iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
+        iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device))

         # get corresponding r, phi, x and y coordinates
         r = r[iidx[:, 1], iidx[:, 2]] / r_cutoff
```
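torch.full_like defaults to the device of its input tensor, so passing device=r.device explicitly pins the all-True mask to the radii's device even when ikernel was created without a device argument; that keeps the `&` and the subsequent argwhere on a single device. A small sketch of the masking pattern, with made-up shapes:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative shapes only: one kernel-index axis broadcast against a 2D grid of radii.
ikernel = torch.arange(5).reshape(-1, 1, 1)      # may live on the CPU
r = torch.rand(1, 8, 16, device=device)          # radii on the target device
r_cutoff = 0.5

# Building the all-True mask on r.device keeps the "&" on one device,
# even if ikernel itself was created without a device argument.
mask = (r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device)
iidx = torch.argwhere(mask)   # (nnz, 3) indices into the broadcast shape (5, 8, 16)
print(iidx.shape, iidx.device)
```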
```diff
@@ -316,10 +316,10 @@ class ZernikeFilterBasis(FilterBasis):
         """

         # enumerator for basis function
-        ikernel = torch.arange(self.kernel_size).reshape(-1, 1, 1)
+        ikernel = torch.arange(self.kernel_size, device=r.device).reshape(-1, 1, 1)

         # get relevant indices
-        iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool))
+        iidx = torch.argwhere((r <= r_cutoff) & torch.full_like(ikernel, True, dtype=torch.bool, device=r.device))

         # indexing logic for zernike polynomials
         # the total index is given by (n * (n + 2) + l ) // 2 which needs to be reversed
```
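The comment in this hunk notes that the flattened Zernike index (n * (n + 2) + l) // 2 has to be reversed. As a worked illustration only, using the OSA/ANSI single-index convention and not necessarily the exact formula in filter_basis.py, the inversion can be written as:

```python
import math

def zernike_nl(j: int) -> tuple[int, int]:
    # Invert j = (n * (n + 2) + l) // 2, where l runs over
    # -n, -n + 2, ..., n for each radial degree n (OSA/ANSI convention).
    n = math.ceil((-3 + math.sqrt(9 + 8 * j)) / 2)
    l = 2 * j - n * (n + 2)
    return n, l

# j = 0..5 maps back to (0, 0), (1, -1), (1, 1), (2, -2), (2, 0), (2, 2).
print([zernike_nl(j) for j in range(6)])
```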