Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
170dee9c
"vscode:/vscode.git/clone" did not exist on "18ead1935557a2d11cac44bb5dfd82f3d63ea682"
Commit
170dee9c
authored
Jun 09, 2023
by
Boris Bonev
Browse files
Changed docstrings to raw strings
parent
80361eaf
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
23 additions
and
19 deletions
+23
-19
.github/workflows/tests.yml
.github/workflows/tests.yml
+3
-1
torch_harmonics/legendre.py
torch_harmonics/legendre.py
+2
-2
torch_harmonics/quadrature.py
torch_harmonics/quadrature.py
+4
-4
torch_harmonics/random_fields.py
torch_harmonics/random_fields.py
+4
-2
torch_harmonics/sht.py
torch_harmonics/sht.py
+10
-10
No files found.
.github/workflows/tests.yml
View file @
170dee9c
...
@@ -24,7 +24,9 @@ jobs:
...
@@ -24,7 +24,9 @@ jobs:
# python -m pip install -e .
# python -m pip install -e .
# cd ../..
# cd ../..
-
name
:
Install dependencies
-
name
:
Install dependencies
run
:
python -m pip install --upgrade pip setuptools wheel
run
:
|
python -m pip install --upgrade pip setuptools wheel
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
-
name
:
Install package
-
name
:
Install package
run
:
|
run
:
|
python -m pip install -e .
python -m pip install -e .
...
...
torch_harmonics/legendre.py
View file @
170dee9c
...
@@ -40,7 +40,7 @@ def clm(l, m):
...
@@ -40,7 +40,7 @@ def clm(l, m):
def
precompute_legpoly
(
mmax
,
lmax
,
t
,
norm
=
"ortho"
,
inverse
=
False
,
csphase
=
True
):
def
precompute_legpoly
(
mmax
,
lmax
,
t
,
norm
=
"ortho"
,
inverse
=
False
,
csphase
=
True
):
"""
r
"""
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by x (theta)
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by x (theta)
The resulting tensor has shape (mmax, lmax, len(x)).
The resulting tensor has shape (mmax, lmax, len(x)).
The Condon-Shortley Phase (-1)^m can be turned off optionally
The Condon-Shortley Phase (-1)^m can be turned off optionally
...
@@ -92,7 +92,7 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True)
...
@@ -92,7 +92,7 @@ def precompute_legpoly(mmax, lmax, t, norm="ortho", inverse=False, csphase=True)
return
torch
.
from_numpy
(
pct
)
return
torch
.
from_numpy
(
pct
)
def
precompute_dlegpoly
(
mmax
,
lmax
,
x
,
norm
=
"ortho"
,
inverse
=
False
,
csphase
=
True
):
def
precompute_dlegpoly
(
mmax
,
lmax
,
x
,
norm
=
"ortho"
,
inverse
=
False
,
csphase
=
True
):
"""
r
"""
Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
at the positions specified by x (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
at the positions specified by x (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
needed for the computation of the vector spherical harmonics. The resulting tensor has shape
needed for the computation of the vector spherical harmonics. The resulting tensor has shape
...
...
torch_harmonics/quadrature.py
View file @
170dee9c
...
@@ -32,7 +32,7 @@
...
@@ -32,7 +32,7 @@
import
numpy
as
np
import
numpy
as
np
def
legendre_gauss_weights
(
n
,
a
=-
1.0
,
b
=
1.0
):
def
legendre_gauss_weights
(
n
,
a
=-
1.0
,
b
=
1.0
):
"""
r
"""
Helper routine which returns the Legendre-Gauss nodes and weights
Helper routine which returns the Legendre-Gauss nodes and weights
on the interval [a, b]
on the interval [a, b]
"""
"""
...
@@ -44,7 +44,7 @@ def legendre_gauss_weights(n, a=-1.0, b=1.0):
...
@@ -44,7 +44,7 @@ def legendre_gauss_weights(n, a=-1.0, b=1.0):
return
xlg
,
wlg
return
xlg
,
wlg
def
lobatto_weights
(
n
,
a
=-
1.0
,
b
=
1.0
,
tol
=
1e-16
,
maxiter
=
100
):
def
lobatto_weights
(
n
,
a
=-
1.0
,
b
=
1.0
,
tol
=
1e-16
,
maxiter
=
100
):
"""
r
"""
Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
on the interval [a, b]
on the interval [a, b]
"""
"""
...
@@ -86,7 +86,7 @@ def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
...
@@ -86,7 +86,7 @@ def lobatto_weights(n, a=-1.0, b=1.0, tol=1e-16, maxiter=100):
def
clenshaw_curtiss_weights
(
n
,
a
=-
1.0
,
b
=
1.0
):
def
clenshaw_curtiss_weights
(
n
,
a
=-
1.0
,
b
=
1.0
):
"""
r
"""
Computation of the Clenshaw-Curtis quadrature nodes and weights.
Computation of the Clenshaw-Curtis quadrature nodes and weights.
This implementation follows
This implementation follows
...
@@ -123,7 +123,7 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
...
@@ -123,7 +123,7 @@ def clenshaw_curtiss_weights(n, a=-1.0, b=1.0):
return
tcc
,
wcc
return
tcc
,
wcc
def
fejer2_weights
(
n
,
a
=-
1.0
,
b
=
1.0
):
def
fejer2_weights
(
n
,
a
=-
1.0
,
b
=
1.0
):
"""
r
"""
Computation of the Fejer quadrature nodes and weights.
Computation of the Fejer quadrature nodes and weights.
This implementation follows
This implementation follows
...
...
torch_harmonics/random_fields.py
View file @
170dee9c
...
@@ -35,7 +35,8 @@ from .sht import InverseRealSHT
...
@@ -35,7 +35,8 @@ from .sht import InverseRealSHT
class
GaussianRandomFieldS2
(
torch
.
nn
.
Module
):
class
GaussianRandomFieldS2
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
nlat
,
alpha
=
2.0
,
tau
=
3.0
,
sigma
=
None
,
radius
=
1.0
,
grid
=
"equiangular"
,
dtype
=
torch
.
float32
):
def
__init__
(
self
,
nlat
,
alpha
=
2.0
,
tau
=
3.0
,
sigma
=
None
,
radius
=
1.0
,
grid
=
"equiangular"
,
dtype
=
torch
.
float32
):
super
().
__init__
()
super
().
__init__
()
"""A mean-zero Gaussian Random Field on the sphere with Matern covariance:
r
"""
A mean-zero Gaussian Random Field on the sphere with Matern covariance:
C = sigma^2 (-Lap + tau^2 I)^(-alpha).
C = sigma^2 (-Lap + tau^2 I)^(-alpha).
Lap is the Laplacian on the sphere, I the identity operator,
Lap is the Laplacian on the sphere, I the identity operator,
...
@@ -93,7 +94,8 @@ class GaussianRandomFieldS2(torch.nn.Module):
...
@@ -93,7 +94,8 @@ class GaussianRandomFieldS2(torch.nn.Module):
self
.
gaussian_noise
=
torch
.
distributions
.
normal
.
Normal
(
self
.
mean
,
self
.
var
)
self
.
gaussian_noise
=
torch
.
distributions
.
normal
.
Normal
(
self
.
mean
,
self
.
var
)
def
forward
(
self
,
N
,
xi
=
None
):
def
forward
(
self
,
N
,
xi
=
None
):
"""Sample random functions from a spherical GRF.
r
"""
Sample random functions from a spherical GRF.
Parameters
Parameters
----------
----------
...
...
torch_harmonics/sht.py
View file @
170dee9c
...
@@ -39,7 +39,7 @@ from .legendre import *
...
@@ -39,7 +39,7 @@ from .legendre import *
class
RealSHT
(
nn
.
Module
):
class
RealSHT
(
nn
.
Module
):
"""
r
"""
Defines a module for computing the forward (real-valued) SHT.
Defines a module for computing the forward (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last two dimensions of the input
The SHT is applied to the last two dimensions of the input
...
@@ -49,7 +49,7 @@ class RealSHT(nn.Module):
...
@@ -49,7 +49,7 @@ class RealSHT(nn.Module):
"""
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
,
grid
=
"lobatto"
,
norm
=
"ortho"
,
csphase
=
True
):
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
,
grid
=
"lobatto"
,
norm
=
"ortho"
,
csphase
=
True
):
"""
r
"""
Initializes the SHT Layer, precomputing the necessary quadrature weights
Initializes the SHT Layer, precomputing the necessary quadrature weights
Parameters:
Parameters:
...
@@ -97,7 +97,7 @@ class RealSHT(nn.Module):
...
@@ -97,7 +97,7 @@ class RealSHT(nn.Module):
self
.
register_buffer
(
'weights'
,
weights
,
persistent
=
False
)
self
.
register_buffer
(
'weights'
,
weights
,
persistent
=
False
)
def
extra_repr
(
self
):
def
extra_repr
(
self
):
"""
r
"""
Pretty print module
Pretty print module
"""
"""
return
f
'nlat=
{
self
.
nlat
}
, nlon=
{
self
.
nlon
}
,
\n
lmax=
{
self
.
lmax
}
, mmax=
{
self
.
mmax
}
,
\n
grid=
{
self
.
grid
}
, csphase=
{
self
.
csphase
}
'
return
f
'nlat=
{
self
.
nlat
}
, nlon=
{
self
.
nlon
}
,
\n
lmax=
{
self
.
lmax
}
, mmax=
{
self
.
mmax
}
,
\n
grid=
{
self
.
grid
}
, csphase=
{
self
.
csphase
}
'
...
@@ -127,7 +127,7 @@ class RealSHT(nn.Module):
...
@@ -127,7 +127,7 @@ class RealSHT(nn.Module):
return
x
return
x
class
InverseRealSHT
(
nn
.
Module
):
class
InverseRealSHT
(
nn
.
Module
):
"""
r
"""
Defines a module for computing the inverse (real-valued) SHT.
Defines a module for computing the inverse (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
nlat, nlon: Output dimensions
nlat, nlon: Output dimensions
...
@@ -172,7 +172,7 @@ class InverseRealSHT(nn.Module):
...
@@ -172,7 +172,7 @@ class InverseRealSHT(nn.Module):
self
.
register_buffer
(
'pct'
,
pct
,
persistent
=
False
)
self
.
register_buffer
(
'pct'
,
pct
,
persistent
=
False
)
def
extra_repr
(
self
):
def
extra_repr
(
self
):
"""
r
"""
Pretty print module
Pretty print module
"""
"""
return
f
'nlat=
{
self
.
nlat
}
, nlon=
{
self
.
nlon
}
,
\n
lmax=
{
self
.
lmax
}
, mmax=
{
self
.
mmax
}
,
\n
grid=
{
self
.
grid
}
, csphase=
{
self
.
csphase
}
'
return
f
'nlat=
{
self
.
nlat
}
, nlon=
{
self
.
nlon
}
,
\n
lmax=
{
self
.
lmax
}
, mmax=
{
self
.
mmax
}
,
\n
grid=
{
self
.
grid
}
, csphase=
{
self
.
csphase
}
'
...
@@ -197,7 +197,7 @@ class InverseRealSHT(nn.Module):
...
@@ -197,7 +197,7 @@ class InverseRealSHT(nn.Module):
class
RealVectorSHT
(
nn
.
Module
):
class
RealVectorSHT
(
nn
.
Module
):
"""
r
"""
Defines a module for computing the forward (real) vector SHT.
Defines a module for computing the forward (real) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last three dimensions of the input.
The SHT is applied to the last three dimensions of the input.
...
@@ -207,7 +207,7 @@ class RealVectorSHT(nn.Module):
...
@@ -207,7 +207,7 @@ class RealVectorSHT(nn.Module):
"""
"""
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
,
grid
=
"lobatto"
,
norm
=
"ortho"
,
csphase
=
True
):
def
__init__
(
self
,
nlat
,
nlon
,
lmax
=
None
,
mmax
=
None
,
grid
=
"lobatto"
,
norm
=
"ortho"
,
csphase
=
True
):
"""
r
"""
Initializes the vector SHT Layer, precomputing the necessary quadrature weights
Initializes the vector SHT Layer, precomputing the necessary quadrature weights
Parameters:
Parameters:
...
@@ -259,7 +259,7 @@ class RealVectorSHT(nn.Module):
...
@@ -259,7 +259,7 @@ class RealVectorSHT(nn.Module):
self
.
register_buffer
(
'weights'
,
weights
,
persistent
=
False
)
self
.
register_buffer
(
'weights'
,
weights
,
persistent
=
False
)
def
extra_repr
(
self
):
def
extra_repr
(
self
):
"""
r
"""
Pretty print module
Pretty print module
"""
"""
return
f
'nlat=
{
self
.
nlat
}
, nlon=
{
self
.
nlon
}
,
\n
lmax=
{
self
.
lmax
}
, mmax=
{
self
.
mmax
}
,
\n
grid=
{
self
.
grid
}
, csphase=
{
self
.
csphase
}
'
return
f
'nlat=
{
self
.
nlat
}
, nlon=
{
self
.
nlon
}
,
\n
lmax=
{
self
.
lmax
}
, mmax=
{
self
.
mmax
}
,
\n
grid=
{
self
.
grid
}
, csphase=
{
self
.
csphase
}
'
...
@@ -301,7 +301,7 @@ class RealVectorSHT(nn.Module):
...
@@ -301,7 +301,7 @@ class RealVectorSHT(nn.Module):
class
InverseRealVectorSHT
(
nn
.
Module
):
class
InverseRealVectorSHT
(
nn
.
Module
):
"""
r
"""
Defines a module for computing the inverse (real-valued) vector SHT.
Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
...
@@ -343,7 +343,7 @@ class InverseRealVectorSHT(nn.Module):
...
@@ -343,7 +343,7 @@ class InverseRealVectorSHT(nn.Module):
self
.
register_buffer
(
'dpct'
,
dpct
,
persistent
=
False
)
self
.
register_buffer
(
'dpct'
,
dpct
,
persistent
=
False
)
def
extra_repr
(
self
):
def
extra_repr
(
self
):
"""
r
"""
Pretty print module
Pretty print module
"""
"""
return
f
'nlat=
{
self
.
nlat
}
, nlon=
{
self
.
nlon
}
,
\n
lmax=
{
self
.
lmax
}
, mmax=
{
self
.
mmax
}
,
\n
grid=
{
self
.
grid
}
, csphase=
{
self
.
csphase
}
'
return
f
'nlat=
{
self
.
nlat
}
, nlon=
{
self
.
nlon
}
,
\n
lmax=
{
self
.
lmax
}
, mmax=
{
self
.
mmax
}
,
\n
grid=
{
self
.
grid
}
, csphase=
{
self
.
csphase
}
'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment