OpenDAS / torch-harmonics / Commits

Commit 4770621b
authored Jul 17, 2025 by Andrea Paris, committed by Boris Bonev on Jul 21, 2025

cleanup of tests

parent 30d8b2da
Showing 4 changed files with 2 additions and 152 deletions (+2, -152)
tests/test_distributed_convolution.py   +0  -52
tests/test_distributed_resample.py      +1  -37
tests/test_distributed_sht.py           +1  -55
tests/test_sht.py                       +0  -8
tests/test_distributed_convolution.py
@@ -45,15 +45,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """
        Set up the distributed convolution test.

        Parameters
        ----------
        cls : TestDistributedDiscreteContinuousConvolution
            The test class instance
        """

        # set up distributed
        cls.world_rank = int(os.getenv("WORLD_RANK", 0))
        cls.grid_size_h = int(os.getenv("GRID_H", 1))
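
Note: the environment variables read in setUpClass (WORLD_RANK, GRID_H) are expected to be exported by the test launcher. A minimal sketch of how such variables could drive a torch.distributed process-group setup follows; everything beyond the two variable names visible in the hunk (WORLD_SIZE, GRID_W, the backend choice) is an assumption for illustration, not taken from this commit.

# Illustrative sketch only (assumed setup, not code from this commit).
import os
import torch
import torch.distributed as dist

world_rank = int(os.getenv("WORLD_RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))   # assumed launcher-provided
grid_size_h = int(os.getenv("GRID_H", 1))      # process-grid extent along latitude
grid_size_w = int(os.getenv("GRID_W", 1))      # assumed analogue along longitude

if world_size > 1:
    # NCCL for GPU runs, gloo as CPU fallback; MASTER_ADDR/MASTER_PORT come from the launcher
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=world_rank, world_size=world_size)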
@@ -216,41 +207,6 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
    def test_distributed_disco_conv(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, kernel_shape, basis_type, basis_norm_mode, groups, grid_in, grid_out, transpose, tol):
        """
        Test the distributed discrete-continuous convolution module.

        Parameters
        ----------
        nlat_in : int
            Number of latitude points in input
        nlon_in : int
            Number of longitude points in input
        nlat_out : int
            Number of latitude points in output
        nlon_out : int
            Number of longitude points in output
        batch_size : int
            Batch size
        num_chan : int
            Number of channels
        kernel_shape : tuple
            Kernel shape
        basis_type : str
            Basis type
        basis_norm_mode : str
            Basis normalization mode
        groups : int
            Number of groups
        grid_in : str
            Grid type for input
        grid_out : str
            Grid type for output
        transpose : bool
            Whether to transpose the convolution
        tol : float
            Tolerance for numerical equivalence
        """

        B, C, H, W = batch_size, num_chan, nlat_in, nlon_in
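
The long signature above is populated from a parameter list that is expanded into individual test cases. As a hedged illustration of that pattern with the parameterized package (the concrete values below are made up, not the configurations used in torch-harmonics):

# Illustration of tuple-per-configuration test expansion; values are invented.
import unittest
from parameterized import parameterized

class ExampleParameterizedTest(unittest.TestCase):
    @parameterized.expand([
        # (nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, tol)
        (32, 64, 32, 64, 2, 4, 1e-5),
        (64, 128, 32, 64, 2, 4, 1e-5),
    ])
    def test_config(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, tol):
        # placeholder assertion; a real test would build and compare the modules
        self.assertEqual(nlon_in, 2 * nlat_in)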
@@ -285,9 +241,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        # create tensors
        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local conv
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = conv_local(inp_full)
@@ -301,9 +255,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed conv
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
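
The _split_helper call above scatters the full input so each rank holds only its local slab of the spatial dimensions. A hypothetical sketch of that idea is below; the real helper in the test suite also manages the communicator layout and uneven splits, so this is illustration only.

# Hypothetical sketch, not the repository's _split_helper.
import torch

def split_spatial(t: torch.Tensor, grid_h: int, grid_w: int, rank_h: int, rank_w: int) -> torch.Tensor:
    # split the latitude dimension across grid_h ranks, then longitude across grid_w ranks
    lat_slab = torch.tensor_split(t, grid_h, dim=-2)[rank_h]
    local = torch.tensor_split(lat_slab, grid_w, dim=-1)[rank_w]
    return local.contiguous()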
@@ -315,9 +267,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, conv_dist)

            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
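
The error metric above is a per-field relative Frobenius norm, averaged over the batch and channel dimensions. A self-contained illustration with random data:

# Stand-alone illustration of the relative-error metric used by these tests.
import torch

ref = torch.randn(2, 3, 8, 16)                  # (batch, channel, lat, lon)
approx = ref + 1e-6 * torch.randn_like(ref)     # small perturbation

rel_err = torch.mean(
    torch.norm(ref - approx, p="fro", dim=(-1, -2)) / torch.norm(ref, p="fro", dim=(-1, -2))
)
print(rel_err.item())   # on the order of 1e-6 for this perturbation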
@@ -325,9 +275,7 @@ class TestDistributedDiscreteContinuousConvolution(unittest.TestCase):
            print(f"final relative error of output: {err.item()}")
        self.assertTrue(err.item() <= tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, conv_dist)
tests/test_distributed_resample.py
@@ -200,35 +200,7 @@ class TestDistributedResampling(unittest.TestCase):
    def test_distributed_resampling(self, nlat_in, nlon_in, nlat_out, nlon_out, batch_size, num_chan, grid_in, grid_out, mode, tol, verbose):
        """
        Test the distributed resampling module.

        Parameters
        ----------
        nlat_in : int
            Number of latitude points in input
        nlon_in : int
            Number of longitude points in input
        nlat_out : int
            Number of latitude points in output
        nlon_out : int
            Number of longitude points in output
        batch_size : int
            Batch size
        num_chan : int
            Number of channels
        grid_in : str
            Grid type for input
        grid_out : str
            Grid type for output
        mode : str
            Resampling mode
        tol : float
            Tolerance for numerical equivalence
        verbose : bool
            Whether to print verbose output
        """

        B, C, H, W = batch_size, num_chan, nlat_in, nlon_in

        res_args = dict(
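
For orientation only: the shapes involved in this resampling test can be mimicked with a plain bilinear interpolation. The snippet below is a generic stand-in, not the library's sphere-aware resampling module, and its sizes are invented; it only shows the (nlat_in, nlon_in) to (nlat_out, nlon_out) mapping.

# Generic stand-in (torch.nn.functional.interpolate), not torch-harmonics' resampling.
import torch
import torch.nn.functional as F

B, C, nlat_in, nlon_in = 2, 4, 32, 64
nlat_out, nlon_out = 64, 128

inp = torch.randn(B, C, nlat_in, nlon_in)
out = F.interpolate(inp, size=(nlat_out, nlon_out), mode="bilinear", align_corners=False)
print(out.shape)   # torch.Size([2, 4, 64, 128])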
@@ -248,9 +220,7 @@ class TestDistributedResampling(unittest.TestCase):
        # create tensors
        inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local conv
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = res_local(inp_full)
@@ -264,9 +234,7 @@ class TestDistributedResampling(unittest.TestCase):
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed conv
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
@@ -278,9 +246,7 @@ class TestDistributedResampling(unittest.TestCase):
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, res_dist)

            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
@@ -288,9 +254,7 @@ class TestDistributedResampling(unittest.TestCase):
            print(f"final relative error of output: {err.item()}")
        self.assertTrue(err.item() <= tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, res_dist)
tests/test_distributed_sht.py
@@ -215,27 +215,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        ]
    )
    def test_distributed_sht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
        """
        Test the distributed spherical harmonic transform.

        Parameters
        ----------
        nlat : int
            Number of latitude points
        nlon : int
            Number of longitude points
        batch_size : int
            Batch size
        num_chan : int
            Number of channels
        grid : str
            Grid type
        vector : bool
            Whether to use vector spherical harmonic transform
        tol : float
            Tolerance for numerical equivalence
        """

        B, C, H, W = batch_size, num_chan, nlat, nlon

        # set up handles
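
The distributed transform is validated against a single-process transform. As a hedged reminder of the single-process interface (class names and grid argument as in the project README; sizes and the use of random input here are illustrative), a forward/inverse round trip looks roughly like:

# Hedged sketch of the single-process SHT interface; parameters are illustrative.
import torch
from torch_harmonics import RealSHT, InverseRealSHT

nlat, nlon = 64, 128
sht = RealSHT(nlat, nlon, grid="equiangular")
isht = InverseRealSHT(nlat, nlon, grid="equiangular")

signal = torch.randn(1, 1, nlat, nlon)
coeffs = sht(signal)     # spatial -> spectral (complex coefficients)
recon = isht(coeffs)     # spectral -> spatial

# the round trip is only exact for band-limited input, so compare with a tolerance
print((recon - signal).abs().max().item())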
@@ -252,9 +232,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        else:
            inp_full = torch.randn((B, C, H, W), dtype=torch.float32, device=self.device)

        #############################################################
        # local transform
        #############################################################
        # FWD pass
        inp_full.requires_grad = True
        out_full = forward_transform_local(inp_full)
@@ -268,9 +246,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed transform
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
@@ -282,9 +258,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_fwd(out_local, B, C, forward_transform_dist, vector)

            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
@@ -292,9 +266,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
            print(f"final relative error of output: {err.item()}")
        self.assertTrue(err.item() <= tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_bwd(igrad_local, B, C, forward_transform_dist, vector)

            err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
@@ -323,26 +295,6 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        ]
    )
    def test_distributed_isht(self, nlat, nlon, batch_size, num_chan, grid, vector, tol):
        """
        Test the distributed inverse spherical harmonic transform.

        Parameters
        ----------
        nlat : int
            Number of latitude points
        nlon : int
            Number of longitude points
        batch_size : int
            Batch size
        num_chan : int
            Number of channels
        grid : str
            Grid type
        vector : bool
            Whether to use vector spherical harmonic transform
        tol : float
            Tolerance for numerical equivalence
        """

        B, C, H, W = batch_size, num_chan, nlat, nlon
@@ -383,9 +335,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        out_full.backward(ograd_full)
        igrad_full = inp_full.grad.clone()

        #############################################################
        # distributed transform
        #############################################################
        # FWD pass
        inp_local = self._split_helper(inp_full)
        inp_local.requires_grad = True
@@ -397,9 +347,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
        out_local.backward(ograd_local)
        igrad_local = inp_local.grad.clone()

        #############################################################
        # evaluate FWD pass
        #############################################################
        with torch.no_grad():
            out_gather_full = self._gather_helper_bwd(out_local, B, C, backward_transform_dist, vector)

            err = torch.mean(torch.norm(out_full - out_gather_full, p="fro", dim=(-1, -2)) / torch.norm(out_full, p="fro", dim=(-1, -2)))
@@ -407,9 +355,7 @@ class TestDistributedSphericalHarmonicTransform(unittest.TestCase):
            print(f"final relative error of output: {err.item()}")
        self.assertTrue(err.item() <= tol)

        #############################################################
        # evaluate BWD pass
        #############################################################
        with torch.no_grad():
            igrad_gather_full = self._gather_helper_fwd(igrad_local, B, C, backward_transform_dist, vector)

            err = torch.mean(torch.norm(igrad_full - igrad_gather_full, p="fro", dim=(-1, -2)) / torch.norm(igrad_full, p="fro", dim=(-1, -2)))
tests/test_sht.py
@@ -65,14 +65,6 @@ class TestLegendrePolynomials(unittest.TestCase):
        self.tol = 1e-9

    def test_legendre(self, verbose=False):
        """
        Test the computation of associated Legendre polynomials.

        Parameters
        ----------
        verbose : bool, optional
            Whether to print verbose output, by default False
        """

        if verbose:
            print("Testing computation of associated Legendre polynomials")
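
An independent way to sanity-check associated Legendre values (not part of this commit or the test suite) is to compare a closed form against scipy for order m = 0, where phase conventions cannot interfere:

# Cross-check sketch using scipy; not code from the torch-harmonics tests.
import numpy as np
from scipy.special import lpmv

x = np.linspace(-1.0, 1.0, 101)
p20_closed_form = 0.5 * (3.0 * x**2 - 1.0)   # P_2^0(x)
p20_scipy = lpmv(0, 2, x)                    # lpmv(order m, degree, x)

assert np.allclose(p20_closed_form, p20_scipy, atol=1e-12)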