Commit 30d8b2da authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

further cleanup

parent ec53e666
......@@ -664,20 +664,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self, semi_transposed: bool = False):
"""
Get the convolution tensor
Parameters
-----------
semi_transposed: bool
Whether to semi-transpose the convolution tensor
Returns
-------
psi: torch.Tensor
Convolution tensor
"""
if semi_transposed:
# we do a semi-transposition to faciliate the computation
tout = self.psi_idx[2] // self.nlon_out
......
......@@ -77,7 +77,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
r"""Fills the input Tensor with values drawn from a truncated
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
......
......@@ -50,7 +50,7 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class DiscreteContinuousEncoder(nn.Module):
r"""
"""
Discrete-continuous encoder for spherical neural operators.
This module performs downsampling using discrete-continuous convolutions on the sphere,
......@@ -122,7 +122,7 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module):
r"""
"""
Discrete-continuous decoder for spherical neural operators.
This module performs upsampling using either spherical harmonic transforms or resampling,
......@@ -376,7 +376,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
class LocalSphericalNeuralOperator(nn.Module):
r"""
"""
LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
operators to accureately model both types of solution operators [1]. The architecture is based on the Spherical
Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
......
......@@ -51,9 +51,7 @@ def _factorial(x: torch.Tensor):
class FilterBasis(metaclass=abc.ABCMeta):
"""
Abstract base class for a filter basis
"""
"""Abstract base class for a filter basis"""
def __init__(
self,
......@@ -96,9 +94,7 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi
class PiecewiseLinearFilterBasis(FilterBasis):
"""
Tensor-product basis on a disk constructed from piecewise linear basis functions.
"""
"""Tensor-product basis on a disk constructed from piecewise linear basis functions."""
def __init__(
self,
......@@ -116,14 +112,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
@property
def kernel_size(self):
"""
Compute the number of basis functions in the kernel.
Returns
-------
kernel_size: int
The number of basis functions in the kernel
"""
"""Compute the number of basis functions in the kernel."""
return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
......@@ -214,9 +203,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
class MorletFilterBasis(FilterBasis):
"""
Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions
"""
"""Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions."""
def __init__(
self,
......@@ -271,9 +258,7 @@ class MorletFilterBasis(FilterBasis):
class ZernikeFilterBasis(FilterBasis):
"""
Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials
"""
"""Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials"""
def __init__(
self,
......
......@@ -37,25 +37,11 @@ from torch_harmonics.cache import lru_cache
def clm(l: int, m: int) -> float:
"""
defines the normalization factor to orthonormalize the Spherical Harmonics
Parameters
-----------
l: int
Degree of the spherical harmonic
m: int
Order of the spherical harmonic
Returns
-------
out: float
Normalization factor
"""
"""Defines the normalization factor to orthonormalize the Spherical Harmonics."""
return math.sqrt((2*l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l-m) / math.factorial(l+m))
def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r"""
"""
Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.
......@@ -127,7 +113,7 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
@lru_cache(typed=True, copy=True)
def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r"""
"""
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.
......@@ -165,7 +151,7 @@ def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
@lru_cache(typed=True, copy=True)
def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r"""
"""
Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
needed for the computation of the vector spherical harmonics. The resulting tensor has shape
......
......@@ -47,14 +47,6 @@ except ImportError as err:
def check_plotting_dependencies():
"""
Check if required plotting dependencies (matplotlib and cartopy) are available.
Raises
------
ImportError
If matplotlib or cartopy is not installed
"""
if plt is None:
raise ImportError("matplotlib is required for plotting functions. Install it with 'pip install matplotlib'")
if cartopy is None:
......
......@@ -37,7 +37,7 @@ import torch
def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0,
periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Precompute grid points and weights for various quadrature rules.
Parameters
......@@ -103,7 +103,7 @@ def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple
def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0, periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Helper routine which returns equidistant nodes with trapezoidal weights
on the interval [a, b]
......@@ -137,7 +137,7 @@ def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Helper routine which returns the Legendre-Gauss nodes and weights
on the interval [a, b]
......@@ -169,7 +169,7 @@ def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1
def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
tol: Optional[float]=1e-16, maxiter: Optional[int]=100) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
on the interval [a, b]
......@@ -232,7 +232,7 @@ def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Computation of the Clenshaw-Curtis quadrature nodes and weights.
This implementation follows
......@@ -289,7 +289,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
def fejer2_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
"""
Computation of the Fejer quadrature nodes and weights.
Parameters
......
......@@ -137,11 +137,9 @@ class ResampleS2(nn.Module):
def extra_repr(self):
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor):
# do the interpolation in precision of x
lwgt = self.lon_weights.to(x.dtype)
if self.mode == "bilinear":
......@@ -156,19 +154,6 @@ class ResampleS2(nn.Module):
return x
def _expand_poles(self, x: torch.Tensor):
"""
Expand the input tensor to include pole points for interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Expanded tensor with pole points added
"""
x_north = x[..., 0, :].mean(dim=-1, keepdims=True)
x_south = x[..., -1, :].mean(dim=-1, keepdims=True)
x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode='constant')
......@@ -178,7 +163,6 @@ class ResampleS2(nn.Module):
return x
def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation in precision of x
lwgt = self.lat_weights.to(x.dtype)
if self.mode == "bilinear":
......@@ -193,7 +177,6 @@ class ResampleS2(nn.Module):
return x
def forward(self, x: torch.Tensor):
if self.skip_resampling:
return x
......
......@@ -38,7 +38,7 @@ from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
class RealSHT(nn.Module):
r"""
"""
Defines a module for computing the forward (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last two dimensions of the input
......@@ -149,7 +149,7 @@ class RealSHT(nn.Module):
class InverseRealSHT(nn.Module):
r"""
"""
Defines a module for computing the inverse (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
......@@ -250,7 +250,7 @@ class InverseRealSHT(nn.Module):
class RealVectorSHT(nn.Module):
r"""
"""
Defines a module for computing the forward (real) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last three dimensions of the input.
......@@ -373,7 +373,7 @@ class RealVectorSHT(nn.Module):
class InverseRealVectorSHT(nn.Module):
r"""
"""
Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment