Unverified Commit c7afb546 authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

Merge pull request #95 from NVIDIA/aparis/docs

Docstrings PR
parents b5c410c0 644465ba
Pipeline #2854 canceled with stages
......@@ -37,6 +37,32 @@ import torch
def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0,
periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Precompute grid points and weights for various quadrature rules.
Parameters
-----------
n : int
Number of grid points
grid : str, optional
Grid type ("equidistant", "legendre-gauss", "lobatto", "equiangular"), by default "equidistant"
a : float, optional
Lower bound of interval, by default 0.0
b : float, optional
Upper bound of interval, by default 1.0
periodic : bool, optional
Whether the grid is periodic (only for equidistant), by default False
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
Grid points and weights
Raises
------
ValueError
If periodic is True for non-equidistant grids or unknown grid type
"""
if (grid != "equidistant") and periodic:
raise ValueError(f"Periodic grid is only supported on equidistant grids.")
......@@ -57,19 +83,13 @@ def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[floa
@lru_cache(typed=True, copy=True)
def _precompute_longitudes(nlon: int):
r"""
Convenience routine to precompute longitudes
"""
lons = torch.linspace(0, 2 * math.pi, nlon+1, dtype=torch.float64, requires_grad=False)[:-1]
return lons
@lru_cache(typed=True, copy=True)
def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Convenience routine to precompute latitudes
"""
# compute coordinates in the cosine theta domain
xlg, wlg = _precompute_grid(nlat, grid=grid, a=-1.0, b=1.0, periodic=False)
......@@ -83,9 +103,27 @@ def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple
def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0, periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Helper routine which returns equidistant nodes with trapezoidal weights
on the interval [a, b]
Parameters
-----------
n: int
Number of quadrature nodes
a: Optional[float]
Lower bound of the interval
b: Optional[float]
Upper bound of the interval
periodic: Optional[bool]
Whether the grid is periodic
Returns
-------
xlg: torch.Tensor
Tensor of quadrature nodes
wlg: torch.Tensor
Tensor of quadrature weights
"""
xlg = torch.as_tensor(np.linspace(a, b, n, endpoint=periodic))
......@@ -99,9 +137,25 @@ def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Helper routine which returns the Legendre-Gauss nodes and weights
on the interval [a, b]
Parameters
-----------
n: int
Number of quadrature nodes
a: Optional[float]
Lower bound of the interval
b: Optional[float]
Upper bound of the interval
Returns
-------
xlg: torch.Tensor
Tensor of quadrature nodes
wlg: torch.Tensor
Tensor of quadrature weights
"""
xlg, wlg = np.polynomial.legendre.leggauss(n)
......@@ -115,9 +169,30 @@ def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1
def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
tol: Optional[float]=1e-16, maxiter: Optional[int]=100) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
on the interval [a, b]
Parameters
-----------
n: int
Number of quadrature nodes
a: Optional[float]
Lower bound of the interval
b: Optional[float]
Upper bound of the interval
tol: Optional[float]
Tolerance for the iteration
maxiter: Optional[int]
Maximum number of iterations
Returns
-------
tlg: torch.Tensor
Tensor of quadrature nodes
wlg: torch.Tensor
Tensor of quadrature weights
"""
wlg = torch.zeros((n,), dtype=torch.float64, requires_grad=False)
......@@ -157,10 +232,28 @@ def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Computation of the Clenshaw-Curtis quadrature nodes and weights.
This implementation follows
Parameters
-----------
n: int
Number of quadrature nodes
a: Optional[float]
Lower bound of the interval
b: Optional[float]
Upper bound of the interval
Returns
-------
tcc: torch.Tensor
Tensor of quadrature nodes
wcc: torch.Tensor
Tensor of quadrature weights
References
----------
[1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
"""
......@@ -196,10 +289,27 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
def fejer2_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Computation of the Fejer quadrature nodes and weights.
This implementation follows
Parameters
-----------
n: int
Number of quadrature nodes
a: Optional[float]
Lower bound of the interval
b: Optional[float]
Upper bound of the interval
Returns
-------
tcc: torch.Tensor
Tensor of quadrature nodes
wcc: torch.Tensor
Tensor of quadrature weights
References
----------
[1] Joerg Waldvogel, Fast Construction of the Fejer and Clenshaw-Curtis Quadrature Rules; BIT Numerical Mathematics, Vol. 43, No. 1, pp. 001–018.
"""
......
......@@ -33,37 +33,29 @@ import torch
from .sht import InverseRealSHT
class GaussianRandomFieldS2(torch.nn.Module):
"""
Gaussian random field on the sphere.
Parameters
----------
nlat : int
Number of latitudinal modes.
alpha : float, optional
Exponent of the power spectrum.
tau : float, optional
Cutoff scale of the power spectrum.
sigma : float, optional
Standard deviation of the power spectrum.
radius : float, optional
Radius of the sphere.
grid : str, optional
Grid type.
dtype : torch.dtype, optional
Data type.
"""
def __init__(self, nlat, alpha=2.0, tau=3.0, sigma=None, radius=1.0, grid="equiangular", dtype=torch.float32):
super().__init__()
r"""
A mean-zero Gaussian Random Field on the sphere with Matern covariance:
C = sigma^2 (-Lap + tau^2 I)^(-alpha).
Lap is the Laplacian on the sphere, I the identity operator,
and sigma, tau, alpha are scalar parameters.
Note: C is trace-class on L^2 if and only if alpha > 1.
Parameters
----------
nlat : int
Number of latitudinal modes;
longitudinal modes are 2*nlat.
alpha : float, default is 2
Regularity parameter. Larger means smoother.
tau : float, default is 3
Lenght-scale parameter. Larger means more scales.
sigma : float, default is None
Scale parameter. Larger means bigger.
If None, sigma = tau**(0.5*(2*alpha - 2.0)).
radius : float, default is 1
Radius of the sphere.
grid : string, default is "equiangular"
Grid type. Currently supports "equiangular" and
"legendre-gauss".
dtype : torch.dtype, default is torch.float32
Numerical type for the calculations.
"""
#Number of latitudinal modes.
self.nlat = nlat
......@@ -94,24 +86,7 @@ class GaussianRandomFieldS2(torch.nn.Module):
self.gaussian_noise = torch.distributions.normal.Normal(self.mean, self.var)
def forward(self, N, xi=None):
r"""
Sample random functions from a spherical GRF.
Parameters
----------
N : int
Number of functions to sample.
xi : torch.Tensor, default is None
Noise is a complex tensor of size (N, nlat, nlat+1).
If None, new Gaussian noise is sampled.
If xi is provided, N is ignored.
Output
-------
u : torch.Tensor
N random samples from the GRF returned as a
tensor of size (N, nlat, 2*nlat) on a equiangular grid.
"""
#Sample Gaussian noise.
if xi is None:
xi = self.gaussian_noise.sample(torch.Size((N, self.nlat, self.nlat + 1, 2))).squeeze()
......
......@@ -40,6 +40,30 @@ from torch_harmonics.quadrature import _precompute_latitudes, _precompute_longit
class ResampleS2(nn.Module):
"""
Resampling module for signals on the 2-sphere.
This module provides functionality to resample spherical signals between different
grid resolutions and grid types using bilinear interpolation.
Parameters
-----------
nlat_in : int
Number of latitude points in the input grid
nlon_in : int
Number of longitude points in the input grid
nlat_out : int
Number of latitude points in the output grid
nlon_out : int
Number of longitude points in the output grid
grid_in : str, optional
Input grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
grid_out : str, optional
Output grid type ("equiangular", "legendre-gauss", "lobatto"), by default "equiangular"
mode : str, optional
Interpolation mode ("bilinear", "bilinear-spherical"), by default "bilinear"
"""
def __init__(
self,
nlat_in: int,
......@@ -113,9 +137,6 @@ class ResampleS2(nn.Module):
def extra_repr(self):
r"""
Pretty print module
"""
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor):
......
......@@ -38,24 +38,41 @@ from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
class RealSHT(nn.Module):
r"""
"""
Defines a module for computing the forward (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last two dimensions of the input
Parameters
-----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
r"""
Initializes the SHT Layer, precomputing the necessary quadrature weights
Parameters:
nlat: input grid resolution in the latitudinal direction
nlon: input grid resolution in the longitudinal direction
grid: grid in the latitude direction (for now only tensor product grids are supported)
"""
super().__init__()
......@@ -101,9 +118,6 @@ class RealSHT(nn.Module):
self.register_buffer("weights", weights, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
......@@ -135,12 +149,38 @@ class RealSHT(nn.Module):
class InverseRealSHT(nn.Module):
r"""
"""
Defines a module for computing the inverse (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
nlat, nlon: Output dimensions
lmax, mmax: Input dimensions (spherical coefficients). For convenience, these are inferred from the output dimensions
Parameters
-----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Raises
------
ValueError: If the grid type is unknown
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
......@@ -180,9 +220,6 @@ class InverseRealSHT(nn.Module):
self.register_buffer("pct", pct, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
......@@ -213,24 +250,41 @@ class InverseRealSHT(nn.Module):
class RealVectorSHT(nn.Module):
r"""
"""
Defines a module for computing the forward (real) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last three dimensions of the input.
Parameters
-----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
def __init__(self, nlat, nlon, lmax=None, mmax=None, grid="equiangular", norm="ortho", csphase=True):
r"""
Initializes the vector SHT Layer, precomputing the necessary quadrature weights
Parameters:
nlat: input grid resolution in the latitudinal direction
nlon: input grid resolution in the longitudinal direction
grid: type of grid the data lives on
"""
super().__init__()
......@@ -272,9 +326,6 @@ class RealVectorSHT(nn.Module):
self.register_buffer("weights", weights, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
......@@ -322,10 +373,34 @@ class RealVectorSHT(nn.Module):
class InverseRealVectorSHT(nn.Module):
r"""
"""
Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
Parameters
-----------
nlat: int
Number of latitude points
nlon: int
Number of longitude points
lmax: int
Maximum spherical harmonic degree
mmax: int
Maximum spherical harmonic order
grid: str
Grid type ("equiangular", "legendre-gauss", "lobatto", "equidistant"), by default "equiangular"
norm: str
Normalization type ("ortho", "schmidt", "unnorm"), by default "ortho"
csphase: bool
Whether to apply the Condon-Shortley phase factor, by default True
Returns
-------
x: torch.Tensor
Tensor of shape (..., lmax, mmax)
References
----------
[1] Schaeffer, N. Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Wang, B., Wang, L., Xie, Z.; Accurate calculation of spherical and vector spherical harmonic expansions via spectral element grids; Adv Comput Math.
"""
......@@ -365,9 +440,6 @@ class InverseRealVectorSHT(nn.Module):
self.register_buffer("dpct", dpct, persistent=False)
def extra_repr(self):
r"""
Pretty print module
"""
return f"nlat={self.nlat}, nlon={self.nlon},\n lmax={self.lmax}, mmax={self.mmax},\n grid={self.grid}, csphase={self.csphase}"
def forward(self, x: torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment