Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
87d9bfdc
Commit
87d9bfdc
authored
Jan 13, 2025
by
Boris Bonev
Committed by
Boris Bonev
Jan 14, 2025
Browse files
bugfix in distributed convolution
parent
e5a9c4af
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
23 additions
and
17 deletions
+23
-17
tests/run_tests.sh
tests/run_tests.sh
+1
-0
tests/test_convolution.py
tests/test_convolution.py
+2
-0
tests/test_distributed_convolution.py
tests/test_distributed_convolution.py
+12
-12
torch_harmonics/distributed/distributed_convolution.py
torch_harmonics/distributed/distributed_convolution.py
+8
-5
No files found.
tests/run_tests.sh
View file @
87d9bfdc
...
...
@@ -70,6 +70,7 @@ if [ "$run_distributed" = "true" ]; then
export GRID_W=
${
grid_size_lon
}
;
python3 -m pytest tests/test_distributed_sht.py
python3 -m pytest tests/test_distributed_convolution.py
python3 -m pytest tests/test_distributed_resample.py
"
else
echo
"Skipping distributed tests."
...
...
tests/test_convolution.py
View file @
87d9bfdc
...
...
@@ -177,6 +177,7 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
[
8
,
4
,
2
,
(
24
,
48
),
(
12
,
24
),
[
3
,
3
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
False
,
1e-4
],
[
8
,
4
,
2
,
(
24
,
48
),
(
12
,
24
),
[
4
,
3
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
False
,
1e-4
],
[
8
,
4
,
2
,
(
24
,
48
),
(
12
,
24
),
[
2
,
2
],
"morlet"
,
"mean"
,
"equiangular"
,
"equiangular"
,
False
,
1e-4
],
[
8
,
4
,
2
,
(
24
,
48
),
(
12
,
24
),
[
3
],
"zernike"
,
"mean"
,
"equiangular"
,
"equiangular"
,
False
,
1e-4
],
[
8
,
4
,
2
,
(
16
,
24
),
(
8
,
8
),
[
3
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
False
,
1e-4
],
[
8
,
4
,
2
,
(
18
,
36
),
(
6
,
12
),
[
7
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
False
,
1e-4
],
[
8
,
4
,
2
,
(
16
,
32
),
(
8
,
16
),
[
5
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"legendre-gauss"
,
False
,
1e-4
],
...
...
@@ -188,6 +189,7 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
[
8
,
4
,
2
,
(
12
,
24
),
(
24
,
48
),
[
3
,
3
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
True
,
1e-4
],
[
8
,
4
,
2
,
(
12
,
24
),
(
24
,
48
),
[
4
,
3
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
True
,
1e-4
],
[
8
,
4
,
2
,
(
12
,
24
),
(
24
,
48
),
[
2
,
2
],
"morlet"
,
"mean"
,
"equiangular"
,
"equiangular"
,
True
,
1e-4
],
[
8
,
4
,
2
,
(
12
,
24
),
(
24
,
48
),
[
3
],
"zernike"
,
"mean"
,
"equiangular"
,
"equiangular"
,
True
,
1e-4
],
[
8
,
4
,
2
,
(
8
,
8
),
(
16
,
24
),
[
3
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
True
,
1e-4
],
[
8
,
4
,
2
,
(
6
,
12
),
(
18
,
36
),
[
7
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
True
,
1e-4
],
[
8
,
4
,
2
,
(
8
,
16
),
(
16
,
32
),
[
5
],
"piecewise linear"
,
"mean"
,
"equiangular"
,
"legendre-gauss"
,
True
,
1e-4
],
...
...
tests/test_distributed_convolution.py
View file @
87d9bfdc
...
...
@@ -183,18 +183,18 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
@
parameterized
.
expand
(
[
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
129
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
,
2
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
64
,
128
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
individual
"
,
2
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
6
,
[
3
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
129
,
256
,
129
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
,
2
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
64
,
128
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
individual
"
,
2
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
6
,
[
3
],
"piecewise linear"
,
"
individual
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
129
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
,
2
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
64
,
128
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
mean
"
,
2
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
6
,
[
3
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
False
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
129
,
256
,
129
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
,
2
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
64
,
128
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
8
,
[
3
],
"piecewise linear"
,
"
mean
"
,
2
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
[
128
,
256
,
128
,
256
,
32
,
6
,
[
3
],
"piecewise linear"
,
"
mean
"
,
1
,
"equiangular"
,
"equiangular"
,
True
,
1e-5
],
]
)
def
test_distributed_disco_conv
(
...
...
torch_harmonics/distributed/distributed_convolution.py
View file @
87d9bfdc
...
...
@@ -112,11 +112,12 @@ def _precompute_distributed_convolution_tensor_s2(
# It's imporatant to not include the 2 pi point in the longitudes, as it is equivalent to lon=0
lons_in
=
torch
.
linspace
(
0
,
2
*
math
.
pi
,
nlon_in
+
1
)[:
-
1
]
# compute quadrature weights that will be merged into the Psi tensor
# compute quadrature weights and merge them into the convolution tensor.
# These quadrature integrate to 1 over the sphere.
if
transpose_normalization
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
wout
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
quad_weights
=
torch
.
from_numpy
(
wout
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
/
2.0
else
:
quad_weights
=
2.0
*
torch
.
pi
*
torch
.
from_numpy
(
win
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
quad_weights
=
torch
.
from_numpy
(
win
).
float
().
reshape
(
-
1
,
1
)
/
nlon_in
/
2.0
out_idx
=
[]
out_vals
=
[]
...
...
@@ -129,11 +130,11 @@ def _precompute_distributed_convolution_tensor_s2(
# compute cartesian coordinates of the rotated position
# This uses the YZY convention of Euler angles, where the last angle (alpha) is a passive rotation,
# and therefore applied with a negative sign
z
=
-
torch
.
cos
(
beta
)
*
torch
.
sin
(
alpha
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
alpha
)
*
torch
.
cos
(
gamma
)
x
=
torch
.
cos
(
alpha
)
*
torch
.
cos
(
beta
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
gamma
)
*
torch
.
sin
(
alpha
)
y
=
torch
.
sin
(
beta
)
*
torch
.
sin
(
gamma
)
z
=
-
torch
.
cos
(
beta
)
*
torch
.
sin
(
alpha
)
*
torch
.
sin
(
gamma
)
+
torch
.
cos
(
alpha
)
*
torch
.
cos
(
gamma
)
# normalization is
e
mportant to avoid NaNs when arccos and atan are applied
# normalization is
i
mportant to avoid NaNs when arccos and atan are applied
# this can otherwise lead to spurious artifacts in the solution
norm
=
torch
.
sqrt
(
x
*
x
+
y
*
y
+
z
*
z
)
x
=
x
/
norm
...
...
@@ -270,6 +271,7 @@ class DistributedDiscreteContinuousConvS2(DiscreteContinuousConv):
roff_idx
=
preprocess_psi
(
self
.
kernel_size
,
self
.
nlat_out_local
,
ker_idx
,
row_idx
,
col_idx
,
vals
).
contiguous
()
self
.
register_buffer
(
"psi_roff_idx"
,
roff_idx
,
persistent
=
False
)
# save all datastructures
self
.
register_buffer
(
"psi_ker_idx"
,
ker_idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_row_idx"
,
row_idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_col_idx"
,
col_idx
,
persistent
=
False
)
...
...
@@ -412,6 +414,7 @@ class DistributedDiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
roff_idx
=
preprocess_psi
(
self
.
kernel_size
,
self
.
nlat_in_local
,
ker_idx
,
row_idx
,
col_idx
,
vals
).
contiguous
()
self
.
register_buffer
(
"psi_roff_idx"
,
roff_idx
,
persistent
=
False
)
# save all datastructures
self
.
register_buffer
(
"psi_ker_idx"
,
ker_idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_row_idx"
,
row_idx
,
persistent
=
False
)
self
.
register_buffer
(
"psi_col_idx"
,
col_idx
,
persistent
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment