Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
b5c410c0
Unverified
Commit
b5c410c0
authored
Jul 21, 2025
by
Thorsten Kurth
Committed by
GitHub
Jul 21, 2025
Browse files
Merge pull request #93 from NVIDIA/tkurth/device-fixes
Tkurth/device fixes
parents
4aaff021
3d604f85
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
144 additions
and
28 deletions
+144
-28
tests/test_convolution.py
tests/test_convolution.py
+69
-4
tests/test_sht.py
tests/test_sht.py
+41
-5
torch_harmonics/convolution.py
torch_harmonics/convolution.py
+3
-3
torch_harmonics/csrc/disco/disco_helpers.cpp
torch_harmonics/csrc/disco/disco_helpers.cpp
+18
-3
torch_harmonics/filter_basis.py
torch_harmonics/filter_basis.py
+3
-3
torch_harmonics/legendre.py
torch_harmonics/legendre.py
+4
-4
torch_harmonics/quadrature.py
torch_harmonics/quadrature.py
+1
-1
torch_harmonics/random_fields.py
torch_harmonics/random_fields.py
+3
-3
torch_harmonics/resample.py
torch_harmonics/resample.py
+2
-2
No files found.
tests/test_convolution.py
View file @
b5c410c0
...
@@ -39,7 +39,7 @@ from torch.autograd import gradcheck
...
@@ -39,7 +39,7 @@ from torch.autograd import gradcheck
from
torch_harmonics
import
quadrature
,
DiscreteContinuousConvS2
,
DiscreteContinuousConvTransposeS2
from
torch_harmonics
import
quadrature
,
DiscreteContinuousConvS2
,
DiscreteContinuousConvTransposeS2
from
torch_harmonics.quadrature
import
_precompute_grid
,
_precompute_latitudes
,
_precompute_longitudes
from
torch_harmonics.quadrature
import
_precompute_grid
,
_precompute_latitudes
,
_precompute_longitudes
from
torch_harmonics.convolution
import
_precompute_convolution_tensor_s2
_devices
=
[(
torch
.
device
(
"cpu"
),)]
_devices
=
[(
torch
.
device
(
"cpu"
),)]
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
@@ -127,7 +127,7 @@ def _precompute_convolution_tensor_dense(
...
@@ -127,7 +127,7 @@ def _precompute_convolution_tensor_dense(
quad_weights
=
win
.
reshape
(
-
1
,
1
)
/
nlon_in
/
2.0
quad_weights
=
win
.
reshape
(
-
1
,
1
)
/
nlon_in
/
2.0
# array for accumulating non-zero indices
# array for accumulating non-zero indices
out
=
torch
.
zeros
(
kernel_size
,
nlat_out
,
nlon_out
,
nlat_in
,
nlon_in
,
dtype
=
torch
.
float64
)
out
=
torch
.
zeros
(
kernel_size
,
nlat_out
,
nlon_out
,
nlat_in
,
nlon_in
,
dtype
=
torch
.
float64
,
device
=
lons_in
.
device
)
for
t
in
range
(
nlat_out
):
for
t
in
range
(
nlat_out
):
for
p
in
range
(
nlon_out
):
for
p
in
range
(
nlon_out
):
...
@@ -199,9 +199,10 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
...
@@ -199,9 +199,10 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
[
8
,
4
,
2
,
(
8
,
16
),
(
16
,
32
),
(
5
),
"piecewise linear"
,
"mean"
,
"equiangular"
,
"legendre-gauss"
,
True
,
1e-4
,
False
],
[
8
,
4
,
2
,
(
8
,
16
),
(
16
,
32
),
(
5
),
"piecewise linear"
,
"mean"
,
"equiangular"
,
"legendre-gauss"
,
True
,
1e-4
,
False
],
[
8
,
4
,
2
,
(
8
,
16
),
(
16
,
32
),
(
5
),
"piecewise linear"
,
"mean"
,
"legendre-gauss"
,
"equiangular"
,
True
,
1e-4
,
False
],
[
8
,
4
,
2
,
(
8
,
16
),
(
16
,
32
),
(
5
),
"piecewise linear"
,
"mean"
,
"legendre-gauss"
,
"equiangular"
,
True
,
1e-4
,
False
],
[
8
,
4
,
2
,
(
8
,
16
),
(
16
,
32
),
(
5
),
"piecewise linear"
,
"mean"
,
"legendre-gauss"
,
"legendre-gauss"
,
True
,
1e-4
,
False
],
[
8
,
4
,
2
,
(
8
,
16
),
(
16
,
32
),
(
5
),
"piecewise linear"
,
"mean"
,
"legendre-gauss"
,
"legendre-gauss"
,
True
,
1e-4
,
False
],
]
],
skip_on_empty
=
True
,
)
)
def
test_
disco_convolution
(
def
test_
forward_backward
(
self
,
self
,
batch_size
,
batch_size
,
in_channels
,
in_channels
,
...
@@ -315,6 +316,70 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
...
@@ -315,6 +316,70 @@ class TestDiscreteContinuousConvolution(unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
x_grad
,
x_ref_grad
,
rtol
=
tol
,
atol
=
tol
))
self
.
assertTrue
(
torch
.
allclose
(
x_grad
,
x_ref_grad
,
rtol
=
tol
,
atol
=
tol
))
self
.
assertTrue
(
torch
.
allclose
(
conv
.
weight
.
grad
,
w_ref
.
grad
,
rtol
=
tol
,
atol
=
tol
))
self
.
assertTrue
(
torch
.
allclose
(
conv
.
weight
.
grad
,
w_ref
.
grad
,
rtol
=
tol
,
atol
=
tol
))
@
parameterized
.
expand
(
[
[
8
,
4
,
2
,
(
16
,
32
),
(
16
,
32
),
(
3
),
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
False
,
1e-4
,
False
],
[
8
,
4
,
2
,
(
16
,
32
),
(
8
,
16
),
(
5
),
"piecewise linear"
,
"mean"
,
"legendre-gauss"
,
"legendre-gauss"
,
False
,
1e-4
,
False
],
[
8
,
4
,
2
,
(
16
,
32
),
(
16
,
32
),
(
3
),
"piecewise linear"
,
"mean"
,
"equiangular"
,
"equiangular"
,
True
,
1e-4
,
False
],
[
8
,
4
,
2
,
(
8
,
16
),
(
16
,
32
),
(
5
),
"piecewise linear"
,
"mean"
,
"legendre-gauss"
,
"legendre-gauss"
,
True
,
1e-4
,
False
],
],
skip_on_empty
=
True
,
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
"CUDA is not available"
)
def
test_device_instantiation
(
self
,
batch_size
,
in_channels
,
out_channels
,
in_shape
,
out_shape
,
kernel_shape
,
basis_type
,
basis_norm_mode
,
grid_in
,
grid_out
,
transpose
,
tol
,
verbose
):
nlat_in
,
nlon_in
=
in_shape
nlat_out
,
nlon_out
=
out_shape
if
isinstance
(
kernel_shape
,
int
):
theta_cutoff
=
(
kernel_shape
+
1
)
*
torch
.
pi
/
float
(
nlat_in
-
1
)
else
:
theta_cutoff
=
(
kernel_shape
[
0
]
+
1
)
*
torch
.
pi
/
float
(
nlat_in
-
1
)
# get handle
Conv
=
DiscreteContinuousConvTransposeS2
if
transpose
else
DiscreteContinuousConvS2
# init on cpu
conv_host
=
Conv
(
in_channels
,
out_channels
,
in_shape
,
out_shape
,
kernel_shape
,
basis_type
=
basis_type
,
basis_norm_mode
=
basis_norm_mode
,
groups
=
1
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
bias
=
False
,
theta_cutoff
=
theta_cutoff
,
)
#torch.set_default_device(self.device)
with
torch
.
device
(
self
.
device
):
conv_device
=
Conv
(
in_channels
,
out_channels
,
in_shape
,
out_shape
,
kernel_shape
,
basis_type
=
basis_type
,
basis_norm_mode
=
basis_norm_mode
,
groups
=
1
,
grid_in
=
grid_in
,
grid_out
=
grid_out
,
bias
=
False
,
theta_cutoff
=
theta_cutoff
,
)
# since we specified the device specifier everywhere, it should always
# use the cpu and it should be the same everywhere
self
.
assertTrue
(
torch
.
allclose
(
conv_host
.
psi_col_idx
.
cpu
(),
conv_device
.
psi_col_idx
.
cpu
()))
self
.
assertTrue
(
torch
.
allclose
(
conv_host
.
psi_row_idx
.
cpu
(),
conv_device
.
psi_row_idx
.
cpu
()))
self
.
assertTrue
(
torch
.
allclose
(
conv_host
.
psi_roff_idx
.
cpu
(),
conv_device
.
psi_roff_idx
.
cpu
()))
self
.
assertTrue
(
torch
.
allclose
(
conv_host
.
psi_vals
.
cpu
(),
conv_device
.
psi_vals
.
cpu
()))
self
.
assertTrue
(
torch
.
allclose
(
conv_host
.
psi_idx
.
cpu
(),
conv_device
.
psi_idx
.
cpu
()))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
tests/test_sht.py
View file @
b5c410c0
...
@@ -101,15 +101,16 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
...
@@ -101,15 +101,16 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
[
33
,
64
,
32
,
"ortho"
,
"equiangular"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"ortho"
,
"equiangular"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"ortho"
,
"legendre-gauss"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"ortho"
,
"legendre-gauss"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"ortho"
,
"lobatto"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"ortho"
,
"lobatto"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"four-pi"
,
"equiangular"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"four-pi"
,
"equiangular"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"four-pi"
,
"legendre-gauss"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"four-pi"
,
"legendre-gauss"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"four-pi"
,
"lobatto"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"four-pi"
,
"lobatto"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"schmidt"
,
"equiangular"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"schmidt"
,
"equiangular"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"schmidt"
,
"legendre-gauss"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"schmidt"
,
"legendre-gauss"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"schmidt"
,
"lobatto"
,
1e-9
,
False
],
[
33
,
64
,
32
,
"schmidt"
,
"lobatto"
,
1e-9
,
False
],
]
],
skip_on_empty
=
True
,
)
)
def
test_
sht
(
self
,
nlat
,
nlon
,
batch_size
,
norm
,
grid
,
tol
,
verbose
):
def
test_
forward_inverse
(
self
,
nlat
,
nlon
,
batch_size
,
norm
,
grid
,
tol
,
verbose
):
if
verbose
:
if
verbose
:
print
(
f
"Testing real-valued SHT on
{
nlat
}
x
{
nlon
}
{
grid
}
grid with
{
norm
}
normalization on
{
self
.
device
.
type
}
device"
)
print
(
f
"Testing real-valued SHT on
{
nlat
}
x
{
nlon
}
{
grid
}
grid with
{
norm
}
normalization on
{
self
.
device
.
type
}
device"
)
...
@@ -168,9 +169,10 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
...
@@ -168,9 +169,10 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
[
15
,
30
,
2
,
"schmidt"
,
"equiangular"
,
1e-5
,
False
],
[
15
,
30
,
2
,
"schmidt"
,
"equiangular"
,
1e-5
,
False
],
[
15
,
30
,
2
,
"schmidt"
,
"legendre-gauss"
,
1e-5
,
False
],
[
15
,
30
,
2
,
"schmidt"
,
"legendre-gauss"
,
1e-5
,
False
],
[
15
,
30
,
2
,
"schmidt"
,
"lobatto"
,
1e-5
,
False
],
[
15
,
30
,
2
,
"schmidt"
,
"lobatto"
,
1e-5
,
False
],
]
],
skip_on_empty
=
True
,
)
)
def
test_
sht_
grads
(
self
,
nlat
,
nlon
,
batch_size
,
norm
,
grid
,
tol
,
verbose
):
def
test_grads
(
self
,
nlat
,
nlon
,
batch_size
,
norm
,
grid
,
tol
,
verbose
):
if
verbose
:
if
verbose
:
print
(
f
"Testing gradients of real-valued SHT on
{
nlat
}
x
{
nlon
}
{
grid
}
grid with
{
norm
}
normalization"
)
print
(
f
"Testing gradients of real-valued SHT on
{
nlat
}
x
{
nlon
}
{
grid
}
grid with
{
norm
}
normalization"
)
...
@@ -202,6 +204,40 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
...
@@ -202,6 +204,40 @@ class TestSphericalHarmonicTransform(unittest.TestCase):
test_result
=
gradcheck
(
err_handle
,
grad_input
,
eps
=
1e-6
,
atol
=
tol
)
test_result
=
gradcheck
(
err_handle
,
grad_input
,
eps
=
1e-6
,
atol
=
tol
)
self
.
assertTrue
(
test_result
)
self
.
assertTrue
(
test_result
)
@
parameterized
.
expand
(
[
# even-even
[
12
,
24
,
2
,
"ortho"
,
"equiangular"
,
1e-5
,
False
],
[
12
,
24
,
2
,
"ortho"
,
"legendre-gauss"
,
1e-5
,
False
],
[
12
,
24
,
2
,
"ortho"
,
"lobatto"
,
1e-5
,
False
],
],
skip_on_empty
=
True
,
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
"CUDA is not available"
)
def
test_device_instantiation
(
self
,
nlat
,
nlon
,
batch_size
,
norm
,
grid
,
tol
,
verbose
):
if
verbose
:
print
(
f
"Testing device instantiation of real-valued SHT on
{
nlat
}
x
{
nlon
}
{
grid
}
grid with
{
norm
}
normalization"
)
if
grid
==
"equiangular"
:
mmax
=
nlat
//
2
elif
grid
==
"lobatto"
:
mmax
=
nlat
-
1
else
:
mmax
=
nlat
lmax
=
mmax
# init on cpu
sht_host
=
th
.
RealSHT
(
nlat
,
nlon
,
mmax
=
mmax
,
lmax
=
lmax
,
grid
=
grid
,
norm
=
norm
)
isht_host
=
th
.
InverseRealSHT
(
nlat
,
nlon
,
mmax
=
mmax
,
lmax
=
lmax
,
grid
=
grid
,
norm
=
norm
)
# init on device
with
torch
.
device
(
self
.
device
):
sht_device
=
th
.
RealSHT
(
nlat
,
nlon
,
mmax
=
mmax
,
lmax
=
lmax
,
grid
=
grid
,
norm
=
norm
)
isht_device
=
th
.
InverseRealSHT
(
nlat
,
nlon
,
mmax
=
mmax
,
lmax
=
lmax
,
grid
=
grid
,
norm
=
norm
)
self
.
assertTrue
(
torch
.
allclose
(
sht_host
.
weights
.
cpu
(),
sht_device
.
weights
.
cpu
()))
self
.
assertTrue
(
torch
.
allclose
(
isht_host
.
pct
.
cpu
(),
isht_device
.
pct
.
cpu
()))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
torch_harmonics/convolution.py
View file @
b5c410c0
...
@@ -92,8 +92,8 @@ def _normalize_convolution_tensor_s2(
...
@@ -92,8 +92,8 @@ def _normalize_convolution_tensor_s2(
q
=
quad_weights
[
ilat_in
].
reshape
(
-
1
)
q
=
quad_weights
[
ilat_in
].
reshape
(
-
1
)
# buffer to store intermediate values
# buffer to store intermediate values
vnorm
=
torch
.
zeros
(
kernel_size
,
nlat_out
)
vnorm
=
torch
.
zeros
(
kernel_size
,
nlat_out
,
device
=
psi_vals
.
device
)
support
=
torch
.
zeros
(
kernel_size
,
nlat_out
)
support
=
torch
.
zeros
(
kernel_size
,
nlat_out
,
device
=
psi_vals
.
device
)
# loop through dimensions to compute the norms
# loop through dimensions to compute the norms
for
ik
in
range
(
kernel_size
):
for
ik
in
range
(
kernel_size
):
...
@@ -207,7 +207,7 @@ def _precompute_convolution_tensor_s2(
...
@@ -207,7 +207,7 @@ def _precompute_convolution_tensor_s2(
sgamma
=
torch
.
sin
(
gamma
)
sgamma
=
torch
.
sin
(
gamma
)
# compute row offsets
# compute row offsets
out_roff
=
torch
.
zeros
(
nlat_out
+
1
,
dtype
=
torch
.
int64
)
out_roff
=
torch
.
zeros
(
nlat_out
+
1
,
dtype
=
torch
.
int64
,
device
=
lons_in
.
device
)
out_roff
[
0
]
=
0
out_roff
[
0
]
=
0
for
t
in
range
(
nlat_out
):
for
t
in
range
(
nlat_out
):
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
# the last angle has a negative sign as it is a passive rotation, which rotates the filter around the y-axis
...
...
torch_harmonics/csrc/disco/disco_helpers.cpp
View file @
b5c410c0
...
@@ -104,13 +104,22 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
...
@@ -104,13 +104,22 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
CHECK_INPUT_TENSOR
(
col_idx
);
CHECK_INPUT_TENSOR
(
col_idx
);
CHECK_INPUT_TENSOR
(
val
);
CHECK_INPUT_TENSOR
(
val
);
// get the input device and make sure all tensors are on the same device
auto
device
=
ker_idx
.
device
();
TORCH_INTERNAL_ASSERT
(
device
.
type
()
==
row_idx
.
device
().
type
()
&&
(
device
.
type
()
==
col_idx
.
device
().
type
())
&&
(
device
.
type
()
==
val
.
device
().
type
()));
// move to cpu
ker_idx
=
ker_idx
.
to
(
torch
::
kCPU
);
row_idx
=
row_idx
.
to
(
torch
::
kCPU
);
col_idx
=
col_idx
.
to
(
torch
::
kCPU
);
val
=
val
.
to
(
torch
::
kCPU
);
int64_t
nnz
=
val
.
size
(
0
);
int64_t
nnz
=
val
.
size
(
0
);
int64_t
*
ker_h
=
ker_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
ker_h
=
ker_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
row_h
=
row_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
row_h
=
row_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
col_h
=
col_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
col_h
=
col_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
roff_h
=
new
int64_t
[
Ho
*
K
+
1
];
int64_t
*
roff_h
=
new
int64_t
[
Ho
*
K
+
1
];
int64_t
nrows
;
int64_t
nrows
;
// float *val_h = val.data_ptr<float>();
AT_DISPATCH_FLOATING_TYPES
(
val
.
scalar_type
(),
"preprocess_psi"
,
([
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
val
.
scalar_type
(),
"preprocess_psi"
,
([
&
]
{
preprocess_psi_kernel
<
scalar_t
>
(
nnz
,
K
,
Ho
,
ker_h
,
row_h
,
col_h
,
roff_h
,
preprocess_psi_kernel
<
scalar_t
>
(
nnz
,
K
,
Ho
,
ker_h
,
row_h
,
col_h
,
roff_h
,
...
@@ -118,13 +127,19 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
...
@@ -118,13 +127,19 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
}));
}));
// create output tensor
// create output tensor
auto
options
=
torch
::
TensorOptions
().
dtype
(
row_idx
.
dtype
());
auto
roff_idx
=
torch
::
empty
({
nrows
+
1
},
row_idx
.
options
());
auto
roff_idx
=
torch
::
empty
({
nrows
+
1
},
options
);
int64_t
*
roff_out_h
=
roff_idx
.
data_ptr
<
int64_t
>
();
int64_t
*
roff_out_h
=
roff_idx
.
data_ptr
<
int64_t
>
();
for
(
int64_t
i
=
0
;
i
<
(
nrows
+
1
);
i
++
)
{
roff_out_h
[
i
]
=
roff_h
[
i
];
}
for
(
int64_t
i
=
0
;
i
<
(
nrows
+
1
);
i
++
)
{
roff_out_h
[
i
]
=
roff_h
[
i
];
}
delete
[]
roff_h
;
delete
[]
roff_h
;
// move to original device
ker_idx
=
ker_idx
.
to
(
device
);
row_idx
=
row_idx
.
to
(
device
);
col_idx
=
col_idx
.
to
(
device
);
val
=
val
.
to
(
device
);
roff_idx
=
roff_idx
.
to
(
device
);
return
roff_idx
;
return
roff_idx
;
}
}
...
...
torch_harmonics/filter_basis.py
View file @
b5c410c0
...
@@ -254,7 +254,7 @@ class MorletFilterBasis(FilterBasis):
...
@@ -254,7 +254,7 @@ class MorletFilterBasis(FilterBasis):
mkernel
=
ikernel
//
self
.
kernel_shape
[
1
]
mkernel
=
ikernel
//
self
.
kernel_shape
[
1
]
# get relevant indices
# get relevant indices
iidx
=
torch
.
argwhere
((
r
<=
r_cutoff
)
&
torch
.
full_like
(
ikernel
,
True
,
dtype
=
torch
.
bool
))
iidx
=
torch
.
argwhere
((
r
<=
r_cutoff
)
&
torch
.
full_like
(
ikernel
,
True
,
dtype
=
torch
.
bool
,
device
=
r
.
device
))
# get corresponding r, phi, x and y coordinates
# get corresponding r, phi, x and y coordinates
r
=
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
/
r_cutoff
r
=
r
[
iidx
[:,
1
],
iidx
[:,
2
]]
/
r_cutoff
...
@@ -316,10 +316,10 @@ class ZernikeFilterBasis(FilterBasis):
...
@@ -316,10 +316,10 @@ class ZernikeFilterBasis(FilterBasis):
"""
"""
# enumerator for basis function
# enumerator for basis function
ikernel
=
torch
.
arange
(
self
.
kernel_size
).
reshape
(
-
1
,
1
,
1
)
ikernel
=
torch
.
arange
(
self
.
kernel_size
,
device
=
r
.
device
).
reshape
(
-
1
,
1
,
1
)
# get relevant indices
# get relevant indices
iidx
=
torch
.
argwhere
((
r
<=
r_cutoff
)
&
torch
.
full_like
(
ikernel
,
True
,
dtype
=
torch
.
bool
))
iidx
=
torch
.
argwhere
((
r
<=
r_cutoff
)
&
torch
.
full_like
(
ikernel
,
True
,
dtype
=
torch
.
bool
,
device
=
r
.
device
))
# indexing logic for zernike polynomials
# indexing logic for zernike polynomials
# the total index is given by (n * (n + 2) + l ) // 2 which needs to be reversed
# the total index is given by (n * (n + 2) + l ) // 2 which needs to be reversed
...
...
torch_harmonics/legendre.py
View file @
b5c410c0
...
@@ -57,10 +57,10 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
...
@@ -57,10 +57,10 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
# compute the tensor P^m_n:
# compute the tensor P^m_n:
nmax
=
max
(
mmax
,
lmax
)
nmax
=
max
(
mmax
,
lmax
)
vdm
=
torch
.
zeros
((
nmax
,
nmax
,
len
(
x
)),
dtype
=
torch
.
float64
,
requires_grad
=
False
)
vdm
=
torch
.
zeros
((
nmax
,
nmax
,
len
(
x
)),
dtype
=
torch
.
float64
,
device
=
x
.
device
,
requires_grad
=
False
)
norm_factor
=
1.
if
norm
==
"ortho"
else
math
.
sqrt
(
4
*
math
.
pi
)
norm_factor
=
1.
0
if
norm
==
"ortho"
else
math
.
sqrt
(
4
*
math
.
pi
)
norm_factor
=
1.
/
norm_factor
if
inverse
else
norm_factor
norm_factor
=
1.
0
/
norm_factor
if
inverse
else
norm_factor
# initial values to start the recursion
# initial values to start the recursion
vdm
[
0
,
0
,:]
=
norm_factor
/
math
.
sqrt
(
4
*
math
.
pi
)
vdm
[
0
,
0
,:]
=
norm_factor
/
math
.
sqrt
(
4
*
math
.
pi
)
...
@@ -123,7 +123,7 @@ def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
...
@@ -123,7 +123,7 @@ def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
pct
=
_precompute_legpoly
(
mmax
+
1
,
lmax
+
1
,
t
,
norm
=
norm
,
inverse
=
inverse
,
csphase
=
False
)
pct
=
_precompute_legpoly
(
mmax
+
1
,
lmax
+
1
,
t
,
norm
=
norm
,
inverse
=
inverse
,
csphase
=
False
)
dpct
=
torch
.
zeros
((
2
,
mmax
,
lmax
,
len
(
t
)),
dtype
=
torch
.
float64
,
requires_grad
=
False
)
dpct
=
torch
.
zeros
((
2
,
mmax
,
lmax
,
len
(
t
)),
dtype
=
torch
.
float64
,
device
=
t
.
device
,
requires_grad
=
False
)
# fill the derivative terms wrt theta
# fill the derivative terms wrt theta
for
l
in
range
(
0
,
lmax
):
for
l
in
range
(
0
,
lmax
):
...
...
torch_harmonics/quadrature.py
View file @
b5c410c0
...
@@ -169,7 +169,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
...
@@ -169,7 +169,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
tcc
=
torch
.
cos
(
torch
.
linspace
(
math
.
pi
,
0
,
n
,
dtype
=
torch
.
float64
,
requires_grad
=
False
))
tcc
=
torch
.
cos
(
torch
.
linspace
(
math
.
pi
,
0
,
n
,
dtype
=
torch
.
float64
,
requires_grad
=
False
))
if
n
==
2
:
if
n
==
2
:
wcc
=
torch
.
tensor
([
1.0
,
1.0
],
dtype
=
torch
.
float64
)
wcc
=
torch
.
as_
tensor
([
1.0
,
1.0
],
dtype
=
torch
.
float64
)
else
:
else
:
n1
=
n
-
1
n1
=
n
-
1
...
...
torch_harmonics/random_fields.py
View file @
b5c410c0
...
@@ -77,7 +77,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
...
@@ -77,7 +77,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
self
.
isht
=
InverseRealSHT
(
self
.
nlat
,
2
*
self
.
nlat
,
grid
=
grid
,
norm
=
'backward'
).
to
(
dtype
=
dtype
)
self
.
isht
=
InverseRealSHT
(
self
.
nlat
,
2
*
self
.
nlat
,
grid
=
grid
,
norm
=
'backward'
).
to
(
dtype
=
dtype
)
#Square root of the eigenvalues of C.
#Square root of the eigenvalues of C.
sqrt_eig
=
torch
.
tensor
([
j
*
(
j
+
1
)
for
j
in
range
(
self
.
nlat
)]).
view
(
self
.
nlat
,
1
).
repeat
(
1
,
self
.
nlat
+
1
)
sqrt_eig
=
torch
.
as_
tensor
([
j
*
(
j
+
1
)
for
j
in
range
(
self
.
nlat
)]).
view
(
self
.
nlat
,
1
).
repeat
(
1
,
self
.
nlat
+
1
)
sqrt_eig
=
torch
.
tril
(
sigma
*
(((
sqrt_eig
/
radius
**
2
)
+
tau
**
2
)
**
(
-
alpha
/
2.0
)))
sqrt_eig
=
torch
.
tril
(
sigma
*
(((
sqrt_eig
/
radius
**
2
)
+
tau
**
2
)
**
(
-
alpha
/
2.0
)))
sqrt_eig
[
0
,
0
]
=
0.0
sqrt_eig
[
0
,
0
]
=
0.0
sqrt_eig
=
sqrt_eig
.
unsqueeze
(
0
)
sqrt_eig
=
sqrt_eig
.
unsqueeze
(
0
)
...
@@ -85,8 +85,8 @@ class GaussianRandomFieldS2(torch.nn.Module):
...
@@ -85,8 +85,8 @@ class GaussianRandomFieldS2(torch.nn.Module):
#Save mean and var of the standard Gaussian.
#Save mean and var of the standard Gaussian.
#Need these to re-initialize distribution on a new device.
#Need these to re-initialize distribution on a new device.
mean
=
torch
.
tensor
([
0.0
]).
to
(
dtype
=
dtype
)
mean
=
torch
.
as_
tensor
([
0.0
]).
to
(
dtype
=
dtype
)
var
=
torch
.
tensor
([
1.0
]).
to
(
dtype
=
dtype
)
var
=
torch
.
as_
tensor
([
1.0
]).
to
(
dtype
=
dtype
)
self
.
register_buffer
(
'mean'
,
mean
)
self
.
register_buffer
(
'mean'
,
mean
)
self
.
register_buffer
(
'var'
,
var
)
self
.
register_buffer
(
'var'
,
var
)
...
...
torch_harmonics/resample.py
View file @
b5c410c0
...
@@ -75,9 +75,9 @@ class ResampleS2(nn.Module):
...
@@ -75,9 +75,9 @@ class ResampleS2(nn.Module):
# we need to expand the solution to the poles before interpolating
# we need to expand the solution to the poles before interpolating
self
.
expand_poles
=
(
self
.
lats_out
>
self
.
lats_in
[
-
1
]).
any
()
or
(
self
.
lats_out
<
self
.
lats_in
[
0
]).
any
()
self
.
expand_poles
=
(
self
.
lats_out
>
self
.
lats_in
[
-
1
]).
any
()
or
(
self
.
lats_out
<
self
.
lats_in
[
0
]).
any
()
if
self
.
expand_poles
:
if
self
.
expand_poles
:
self
.
lats_in
=
torch
.
cat
([
torch
.
tensor
([
0.
],
dtype
=
torch
.
float64
),
self
.
lats_in
=
torch
.
cat
([
torch
.
as_
tensor
([
0.
],
dtype
=
torch
.
float64
,
device
=
self
.
lats_in
.
device
),
self
.
lats_in
,
self
.
lats_in
,
torch
.
tensor
([
math
.
pi
],
dtype
=
torch
.
float64
)]).
contiguous
()
torch
.
as_
tensor
([
math
.
pi
],
dtype
=
torch
.
float64
,
device
=
self
.
lats_in
.
device
)]).
contiguous
()
# prepare the interpolation by computing indices to the left and right of each output latitude
# prepare the interpolation by computing indices to the left and right of each output latitude
lat_idx
=
torch
.
searchsorted
(
self
.
lats_in
,
self
.
lats_out
,
side
=
"right"
)
-
1
lat_idx
=
torch
.
searchsorted
(
self
.
lats_in
,
self
.
lats_out
,
side
=
"right"
)
-
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment