Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
30d8b2da
Commit
30d8b2da
authored
Jul 17, 2025
by
Andrea Paris
Committed by
Boris Bonev
Jul 21, 2025
Browse files
further cleanup
parent
ec53e666
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
23 additions
and
91 deletions
+23
-91
torch_harmonics/convolution.py
torch_harmonics/convolution.py
+0
-14
torch_harmonics/examples/models/_layers.py
torch_harmonics/examples/models/_layers.py
+1
-1
torch_harmonics/examples/models/lsno.py
torch_harmonics/examples/models/lsno.py
+3
-3
torch_harmonics/filter_basis.py
torch_harmonics/filter_basis.py
+5
-20
torch_harmonics/legendre.py
torch_harmonics/legendre.py
+4
-18
torch_harmonics/plotting.py
torch_harmonics/plotting.py
+0
-8
torch_harmonics/quadrature.py
torch_harmonics/quadrature.py
+6
-6
torch_harmonics/resample.py
torch_harmonics/resample.py
+0
-17
torch_harmonics/sht.py
torch_harmonics/sht.py
+4
-4
No files found.
torch_harmonics/convolution.py
View file @
30d8b2da
...
@@ -664,20 +664,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
...
@@ -664,20 +664,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
return
torch
.
stack
([
self
.
psi_ker_idx
,
self
.
psi_row_idx
,
self
.
psi_col_idx
],
dim
=
0
).
contiguous
()
return
torch
.
stack
([
self
.
psi_ker_idx
,
self
.
psi_row_idx
,
self
.
psi_col_idx
],
dim
=
0
).
contiguous
()
def
get_psi
(
self
,
semi_transposed
:
bool
=
False
):
def
get_psi
(
self
,
semi_transposed
:
bool
=
False
):
"""
Get the convolution tensor
Parameters
-----------
semi_transposed: bool
Whether to semi-transpose the convolution tensor
Returns
-------
psi: torch.Tensor
Convolution tensor
"""
if
semi_transposed
:
if
semi_transposed
:
# we do a semi-transposition to faciliate the computation
# we do a semi-transposition to faciliate the computation
tout
=
self
.
psi_idx
[
2
]
//
self
.
nlon_out
tout
=
self
.
psi_idx
[
2
]
//
self
.
nlon_out
...
...
torch_harmonics/examples/models/_layers.py
View file @
30d8b2da
...
@@ -77,7 +77,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
...
@@ -77,7 +77,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def
trunc_normal_
(
tensor
,
mean
=
0.0
,
std
=
1.0
,
a
=-
2.0
,
b
=
2.0
):
def
trunc_normal_
(
tensor
,
mean
=
0.0
,
std
=
1.0
,
a
=-
2.0
,
b
=
2.0
):
r
"""Fills the input Tensor with values drawn from a truncated
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(
\t
ext{mean},
\t
ext{std}^2)`
normal distribution :math:`\mathcal{N}(
\t
ext{mean},
\t
ext{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
with values outside :math:`[a, b]` redrawn until they are within
...
...
torch_harmonics/examples/models/lsno.py
View file @
30d8b2da
...
@@ -50,7 +50,7 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
...
@@ -50,7 +50,7 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
return
(
kernel_shape
[
0
]
+
1
)
*
theta_cutoff_factor
[
basis_type
]
*
math
.
pi
/
float
(
nlat
-
1
)
return
(
kernel_shape
[
0
]
+
1
)
*
theta_cutoff_factor
[
basis_type
]
*
math
.
pi
/
float
(
nlat
-
1
)
class
DiscreteContinuousEncoder
(
nn
.
Module
):
class
DiscreteContinuousEncoder
(
nn
.
Module
):
r
"""
"""
Discrete-continuous encoder for spherical neural operators.
Discrete-continuous encoder for spherical neural operators.
This module performs downsampling using discrete-continuous convolutions on the sphere,
This module performs downsampling using discrete-continuous convolutions on the sphere,
...
@@ -122,7 +122,7 @@ class DiscreteContinuousEncoder(nn.Module):
...
@@ -122,7 +122,7 @@ class DiscreteContinuousEncoder(nn.Module):
class
DiscreteContinuousDecoder
(
nn
.
Module
):
class
DiscreteContinuousDecoder
(
nn
.
Module
):
r
"""
"""
Discrete-continuous decoder for spherical neural operators.
Discrete-continuous decoder for spherical neural operators.
This module performs upsampling using either spherical harmonic transforms or resampling,
This module performs upsampling using either spherical harmonic transforms or resampling,
...
@@ -376,7 +376,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
...
@@ -376,7 +376,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
class
LocalSphericalNeuralOperator
(
nn
.
Module
):
class
LocalSphericalNeuralOperator
(
nn
.
Module
):
r
"""
"""
LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
operators to accureately model both types of solution operators [1]. The architecture is based on the Spherical
operators to accureately model both types of solution operators [1]. The architecture is based on the Spherical
Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
...
...
torch_harmonics/filter_basis.py
View file @
30d8b2da
...
@@ -51,9 +51,7 @@ def _factorial(x: torch.Tensor):
...
@@ -51,9 +51,7 @@ def _factorial(x: torch.Tensor):
class
FilterBasis
(
metaclass
=
abc
.
ABCMeta
):
class
FilterBasis
(
metaclass
=
abc
.
ABCMeta
):
"""
"""Abstract base class for a filter basis"""
Abstract base class for a filter basis
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -96,9 +94,7 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi
...
@@ -96,9 +94,7 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi
class
PiecewiseLinearFilterBasis
(
FilterBasis
):
class
PiecewiseLinearFilterBasis
(
FilterBasis
):
"""
"""Tensor-product basis on a disk constructed from piecewise linear basis functions."""
Tensor-product basis on a disk constructed from piecewise linear basis functions.
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -116,14 +112,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
...
@@ -116,14 +112,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
@
property
@
property
def
kernel_size
(
self
):
def
kernel_size
(
self
):
"""
"""Compute the number of basis functions in the kernel."""
Compute the number of basis functions in the kernel.
Returns
-------
kernel_size: int
The number of basis functions in the kernel
"""
return
(
self
.
kernel_shape
[
0
]
//
2
)
*
self
.
kernel_shape
[
1
]
+
self
.
kernel_shape
[
0
]
%
2
return
(
self
.
kernel_shape
[
0
]
//
2
)
*
self
.
kernel_shape
[
1
]
+
self
.
kernel_shape
[
0
]
%
2
def
_compute_support_vals_isotropic
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
):
def
_compute_support_vals_isotropic
(
self
,
r
:
torch
.
Tensor
,
phi
:
torch
.
Tensor
,
r_cutoff
:
float
):
...
@@ -214,9 +203,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
...
@@ -214,9 +203,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
class
MorletFilterBasis
(
FilterBasis
):
class
MorletFilterBasis
(
FilterBasis
):
"""
"""Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions."""
Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -271,9 +258,7 @@ class MorletFilterBasis(FilterBasis):
...
@@ -271,9 +258,7 @@ class MorletFilterBasis(FilterBasis):
class
ZernikeFilterBasis
(
FilterBasis
):
class
ZernikeFilterBasis
(
FilterBasis
):
"""
"""Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials"""
Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials
"""
def
__init__
(
def
__init__
(
self
,
self
,
...
...
torch_harmonics/legendre.py
View file @
30d8b2da
...
@@ -37,25 +37,11 @@ from torch_harmonics.cache import lru_cache
...
@@ -37,25 +37,11 @@ from torch_harmonics.cache import lru_cache
def
clm
(
l
:
int
,
m
:
int
)
->
float
:
def
clm
(
l
:
int
,
m
:
int
)
->
float
:
"""
"""Defines the normalization factor to orthonormalize the Spherical Harmonics."""
defines the normalization factor to orthonormalize the Spherical Harmonics
Parameters
-----------
l: int
Degree of the spherical harmonic
m: int
Order of the spherical harmonic
Returns
-------
out: float
Normalization factor
"""
return
math
.
sqrt
((
2
*
l
+
1
)
/
4
/
math
.
pi
)
*
math
.
sqrt
(
math
.
factorial
(
l
-
m
)
/
math
.
factorial
(
l
+
m
))
return
math
.
sqrt
((
2
*
l
+
1
)
/
4
/
math
.
pi
)
*
math
.
sqrt
(
math
.
factorial
(
l
-
m
)
/
math
.
factorial
(
l
+
m
))
def
legpoly
(
mmax
:
int
,
lmax
:
int
,
x
:
torch
.
Tensor
,
norm
:
Optional
[
str
]
=
"ortho"
,
inverse
:
Optional
[
bool
]
=
False
,
csphase
:
Optional
[
bool
]
=
True
)
->
torch
.
Tensor
:
def
legpoly
(
mmax
:
int
,
lmax
:
int
,
x
:
torch
.
Tensor
,
norm
:
Optional
[
str
]
=
"ortho"
,
inverse
:
Optional
[
bool
]
=
False
,
csphase
:
Optional
[
bool
]
=
True
)
->
torch
.
Tensor
:
r
"""
"""
Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.
can be turned off optionally.
...
@@ -127,7 +113,7 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
...
@@ -127,7 +113,7 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
@
lru_cache
(
typed
=
True
,
copy
=
True
)
@
lru_cache
(
typed
=
True
,
copy
=
True
)
def
_precompute_legpoly
(
mmax
:
int
,
lmax
:
int
,
t
:
torch
.
Tensor
,
def
_precompute_legpoly
(
mmax
:
int
,
lmax
:
int
,
t
:
torch
.
Tensor
,
norm
:
Optional
[
str
]
=
"ortho"
,
inverse
:
Optional
[
bool
]
=
False
,
csphase
:
Optional
[
bool
]
=
True
)
->
torch
.
Tensor
:
norm
:
Optional
[
str
]
=
"ortho"
,
inverse
:
Optional
[
bool
]
=
False
,
csphase
:
Optional
[
bool
]
=
True
)
->
torch
.
Tensor
:
r
"""
"""
Computes the values of (-1)^m c^l_m P^l_m(\cos
\t
heta) at the positions specified by t (theta).
Computes the values of (-1)^m c^l_m P^l_m(\cos
\t
heta) at the positions specified by t (theta).
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.
can be turned off optionally.
...
@@ -165,7 +151,7 @@ def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
...
@@ -165,7 +151,7 @@ def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
@
lru_cache
(
typed
=
True
,
copy
=
True
)
@
lru_cache
(
typed
=
True
,
copy
=
True
)
def
_precompute_dlegpoly
(
mmax
:
int
,
lmax
:
int
,
t
:
torch
.
Tensor
,
def
_precompute_dlegpoly
(
mmax
:
int
,
lmax
:
int
,
t
:
torch
.
Tensor
,
norm
:
Optional
[
str
]
=
"ortho"
,
inverse
:
Optional
[
bool
]
=
False
,
csphase
:
Optional
[
bool
]
=
True
)
->
torch
.
Tensor
:
norm
:
Optional
[
str
]
=
"ortho"
,
inverse
:
Optional
[
bool
]
=
False
,
csphase
:
Optional
[
bool
]
=
True
)
->
torch
.
Tensor
:
r
"""
"""
Computes the values of the derivatives $
\f
rac{d}{d
\t
heta} P^m_l(\cos
\t
heta)$
Computes the values of the derivatives $
\f
rac{d}{d
\t
heta} P^m_l(\cos
\t
heta)$
at the positions specified by t (theta), as well as $
\f
rac{1}{\sin
\t
heta} P^m_l(\cos
\t
heta)$,
at the positions specified by t (theta), as well as $
\f
rac{1}{\sin
\t
heta} P^m_l(\cos
\t
heta)$,
needed for the computation of the vector spherical harmonics. The resulting tensor has shape
needed for the computation of the vector spherical harmonics. The resulting tensor has shape
...
...
torch_harmonics/plotting.py
View file @
30d8b2da
...
@@ -47,14 +47,6 @@ except ImportError as err:
...
@@ -47,14 +47,6 @@ except ImportError as err:
def
check_plotting_dependencies
():
def
check_plotting_dependencies
():
"""
Check if required plotting dependencies (matplotlib and cartopy) are available.
Raises
------
ImportError
If matplotlib or cartopy is not installed
"""
if
plt
is
None
:
if
plt
is
None
:
raise
ImportError
(
"matplotlib is required for plotting functions. Install it with 'pip install matplotlib'"
)
raise
ImportError
(
"matplotlib is required for plotting functions. Install it with 'pip install matplotlib'"
)
if
cartopy
is
None
:
if
cartopy
is
None
:
...
...
torch_harmonics/quadrature.py
View file @
30d8b2da
...
@@ -37,7 +37,7 @@ import torch
...
@@ -37,7 +37,7 @@ import torch
def
_precompute_grid
(
n
:
int
,
grid
:
Optional
[
str
]
=
"equidistant"
,
a
:
Optional
[
float
]
=
0.0
,
b
:
Optional
[
float
]
=
1.0
,
def
_precompute_grid
(
n
:
int
,
grid
:
Optional
[
str
]
=
"equidistant"
,
a
:
Optional
[
float
]
=
0.0
,
b
:
Optional
[
float
]
=
1.0
,
periodic
:
Optional
[
bool
]
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
periodic
:
Optional
[
bool
]
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
"""
Precompute grid points and weights for various quadrature rules.
Precompute grid points and weights for various quadrature rules.
Parameters
Parameters
...
@@ -103,7 +103,7 @@ def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple
...
@@ -103,7 +103,7 @@ def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple
def
trapezoidal_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
,
periodic
:
Optional
[
bool
]
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
trapezoidal_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
,
periodic
:
Optional
[
bool
]
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
"""
Helper routine which returns equidistant nodes with trapezoidal weights
Helper routine which returns equidistant nodes with trapezoidal weights
on the interval [a, b]
on the interval [a, b]
...
@@ -137,7 +137,7 @@ def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
...
@@ -137,7 +137,7 @@ def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
def
legendre_gauss_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
legendre_gauss_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
"""
Helper routine which returns the Legendre-Gauss nodes and weights
Helper routine which returns the Legendre-Gauss nodes and weights
on the interval [a, b]
on the interval [a, b]
...
@@ -169,7 +169,7 @@ def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1
...
@@ -169,7 +169,7 @@ def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1
def
lobatto_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
,
def
lobatto_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
,
tol
:
Optional
[
float
]
=
1e-16
,
maxiter
:
Optional
[
int
]
=
100
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
tol
:
Optional
[
float
]
=
1e-16
,
maxiter
:
Optional
[
int
]
=
100
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
"""
Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
on the interval [a, b]
on the interval [a, b]
...
@@ -232,7 +232,7 @@ def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
...
@@ -232,7 +232,7 @@ def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
def
clenshaw_curtiss_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
clenshaw_curtiss_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
"""
Computation of the Clenshaw-Curtis quadrature nodes and weights.
Computation of the Clenshaw-Curtis quadrature nodes and weights.
This implementation follows
This implementation follows
...
@@ -289,7 +289,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
...
@@ -289,7 +289,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
def
fejer2_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
fejer2_weights
(
n
:
int
,
a
:
Optional
[
float
]
=-
1.0
,
b
:
Optional
[
float
]
=
1.0
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
"""
Computation of the Fejer quadrature nodes and weights.
Computation of the Fejer quadrature nodes and weights.
Parameters
Parameters
...
...
torch_harmonics/resample.py
View file @
30d8b2da
...
@@ -137,11 +137,9 @@ class ResampleS2(nn.Module):
...
@@ -137,11 +137,9 @@ class ResampleS2(nn.Module):
def
extra_repr
(
self
):
def
extra_repr
(
self
):
return
f
"in_shape=
{
(
self
.
nlat_in
,
self
.
nlon_in
)
}
, out_shape=
{
(
self
.
nlat_out
,
self
.
nlon_out
)
}
"
return
f
"in_shape=
{
(
self
.
nlat_in
,
self
.
nlon_in
)
}
, out_shape=
{
(
self
.
nlat_out
,
self
.
nlon_out
)
}
"
def
_upscale_longitudes
(
self
,
x
:
torch
.
Tensor
):
def
_upscale_longitudes
(
self
,
x
:
torch
.
Tensor
):
# do the interpolation in precision of x
# do the interpolation in precision of x
lwgt
=
self
.
lon_weights
.
to
(
x
.
dtype
)
lwgt
=
self
.
lon_weights
.
to
(
x
.
dtype
)
if
self
.
mode
==
"bilinear"
:
if
self
.
mode
==
"bilinear"
:
...
@@ -156,19 +154,6 @@ class ResampleS2(nn.Module):
...
@@ -156,19 +154,6 @@ class ResampleS2(nn.Module):
return
x
return
x
def
_expand_poles
(
self
,
x
:
torch
.
Tensor
):
def
_expand_poles
(
self
,
x
:
torch
.
Tensor
):
"""
Expand the input tensor to include pole points for interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Expanded tensor with pole points added
"""
x_north
=
x
[...,
0
,
:].
mean
(
dim
=-
1
,
keepdims
=
True
)
x_north
=
x
[...,
0
,
:].
mean
(
dim
=-
1
,
keepdims
=
True
)
x_south
=
x
[...,
-
1
,
:].
mean
(
dim
=-
1
,
keepdims
=
True
)
x_south
=
x
[...,
-
1
,
:].
mean
(
dim
=-
1
,
keepdims
=
True
)
x
=
nn
.
functional
.
pad
(
x
,
pad
=
[
0
,
0
,
1
,
1
],
mode
=
'constant'
)
x
=
nn
.
functional
.
pad
(
x
,
pad
=
[
0
,
0
,
1
,
1
],
mode
=
'constant'
)
...
@@ -178,7 +163,6 @@ class ResampleS2(nn.Module):
...
@@ -178,7 +163,6 @@ class ResampleS2(nn.Module):
return
x
return
x
def
_upscale_latitudes
(
self
,
x
:
torch
.
Tensor
):
def
_upscale_latitudes
(
self
,
x
:
torch
.
Tensor
):
# do the interpolation in precision of x
# do the interpolation in precision of x
lwgt
=
self
.
lat_weights
.
to
(
x
.
dtype
)
lwgt
=
self
.
lat_weights
.
to
(
x
.
dtype
)
if
self
.
mode
==
"bilinear"
:
if
self
.
mode
==
"bilinear"
:
...
@@ -193,7 +177,6 @@ class ResampleS2(nn.Module):
...
@@ -193,7 +177,6 @@ class ResampleS2(nn.Module):
return
x
return
x
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
skip_resampling
:
if
self
.
skip_resampling
:
return
x
return
x
...
...
torch_harmonics/sht.py
View file @
30d8b2da
...
@@ -38,7 +38,7 @@ from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
...
@@ -38,7 +38,7 @@ from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
class
RealSHT
(
nn
.
Module
):
class
RealSHT
(
nn
.
Module
):
r
"""
"""
Defines a module for computing the forward (real-valued) SHT.
Defines a module for computing the forward (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last two dimensions of the input
The SHT is applied to the last two dimensions of the input
...
@@ -149,7 +149,7 @@ class RealSHT(nn.Module):
...
@@ -149,7 +149,7 @@ class RealSHT(nn.Module):
class
InverseRealSHT
(
nn
.
Module
):
class
InverseRealSHT
(
nn
.
Module
):
r
"""
"""
Defines a module for computing the inverse (real-valued) SHT.
Defines a module for computing the inverse (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
...
@@ -250,7 +250,7 @@ class InverseRealSHT(nn.Module):
...
@@ -250,7 +250,7 @@ class InverseRealSHT(nn.Module):
class
RealVectorSHT
(
nn
.
Module
):
class
RealVectorSHT
(
nn
.
Module
):
r
"""
"""
Defines a module for computing the forward (real) vector SHT.
Defines a module for computing the forward (real) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last three dimensions of the input.
The SHT is applied to the last three dimensions of the input.
...
@@ -373,7 +373,7 @@ class RealVectorSHT(nn.Module):
...
@@ -373,7 +373,7 @@ class RealVectorSHT(nn.Module):
class
InverseRealVectorSHT
(
nn
.
Module
):
class
InverseRealVectorSHT
(
nn
.
Module
):
r
"""
"""
Defines a module for computing the inverse (real-valued) vector SHT.
Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment