Commit 30d8b2da authored by Andrea Paris's avatar Andrea Paris Committed by Boris Bonev
Browse files

further cleanup

parent ec53e666
...@@ -664,20 +664,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv): ...@@ -664,20 +664,6 @@ class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous() return torch.stack([self.psi_ker_idx, self.psi_row_idx, self.psi_col_idx], dim=0).contiguous()
def get_psi(self, semi_transposed: bool = False): def get_psi(self, semi_transposed: bool = False):
"""
Get the convolution tensor
Parameters
-----------
semi_transposed: bool
Whether to semi-transpose the convolution tensor
Returns
-------
psi: torch.Tensor
Convolution tensor
"""
if semi_transposed: if semi_transposed:
# we do a semi-transposition to faciliate the computation # we do a semi-transposition to faciliate the computation
tout = self.psi_idx[2] // self.nlon_out tout = self.psi_idx[2] // self.nlon_out
......
...@@ -77,7 +77,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b): ...@@ -77,7 +77,7 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
r"""Fills the input Tensor with values drawn from a truncated """Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within with values outside :math:`[a, b]` redrawn until they are within
......
...@@ -50,7 +50,7 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type): ...@@ -50,7 +50,7 @@ def _compute_cutoff_radius(nlat, kernel_shape, basis_type):
return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1) return (kernel_shape[0] + 1) * theta_cutoff_factor[basis_type] * math.pi / float(nlat - 1)
class DiscreteContinuousEncoder(nn.Module): class DiscreteContinuousEncoder(nn.Module):
r""" """
Discrete-continuous encoder for spherical neural operators. Discrete-continuous encoder for spherical neural operators.
This module performs downsampling using discrete-continuous convolutions on the sphere, This module performs downsampling using discrete-continuous convolutions on the sphere,
...@@ -122,7 +122,7 @@ class DiscreteContinuousEncoder(nn.Module): ...@@ -122,7 +122,7 @@ class DiscreteContinuousEncoder(nn.Module):
class DiscreteContinuousDecoder(nn.Module): class DiscreteContinuousDecoder(nn.Module):
r""" """
Discrete-continuous decoder for spherical neural operators. Discrete-continuous decoder for spherical neural operators.
This module performs upsampling using either spherical harmonic transforms or resampling, This module performs upsampling using either spherical harmonic transforms or resampling,
...@@ -376,7 +376,7 @@ class SphericalNeuralOperatorBlock(nn.Module): ...@@ -376,7 +376,7 @@ class SphericalNeuralOperatorBlock(nn.Module):
class LocalSphericalNeuralOperator(nn.Module): class LocalSphericalNeuralOperator(nn.Module):
r""" """
LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral LocalSphericalNeuralOperator module. A spherical neural operator which uses both local and global integral
operators to accureately model both types of solution operators [1]. The architecture is based on the Spherical operators to accureately model both types of solution operators [1]. The architecture is based on the Spherical
Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks, Fourier Neural Operator [2] and improves upon it with local integral operators in both the Neural Operator blocks,
......
...@@ -51,9 +51,7 @@ def _factorial(x: torch.Tensor): ...@@ -51,9 +51,7 @@ def _factorial(x: torch.Tensor):
class FilterBasis(metaclass=abc.ABCMeta): class FilterBasis(metaclass=abc.ABCMeta):
""" """Abstract base class for a filter basis"""
Abstract base class for a filter basis
"""
def __init__( def __init__(
self, self,
...@@ -96,9 +94,7 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi ...@@ -96,9 +94,7 @@ def get_filter_basis(kernel_shape: Union[int, Tuple[int], Tuple[int, int]], basi
class PiecewiseLinearFilterBasis(FilterBasis): class PiecewiseLinearFilterBasis(FilterBasis):
""" """Tensor-product basis on a disk constructed from piecewise linear basis functions."""
Tensor-product basis on a disk constructed from piecewise linear basis functions.
"""
def __init__( def __init__(
self, self,
...@@ -116,14 +112,7 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -116,14 +112,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
@property @property
def kernel_size(self): def kernel_size(self):
""" """Compute the number of basis functions in the kernel."""
Compute the number of basis functions in the kernel.
Returns
-------
kernel_size: int
The number of basis functions in the kernel
"""
return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2 return (self.kernel_shape[0] // 2) * self.kernel_shape[1] + self.kernel_shape[0] % 2
def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float): def _compute_support_vals_isotropic(self, r: torch.Tensor, phi: torch.Tensor, r_cutoff: float):
...@@ -214,9 +203,7 @@ class PiecewiseLinearFilterBasis(FilterBasis): ...@@ -214,9 +203,7 @@ class PiecewiseLinearFilterBasis(FilterBasis):
class MorletFilterBasis(FilterBasis): class MorletFilterBasis(FilterBasis):
""" """Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions."""
Morlet-style filter basis on the disk. A Gaussian is multiplied with a Fourier basis in x and y directions
"""
def __init__( def __init__(
self, self,
...@@ -271,9 +258,7 @@ class MorletFilterBasis(FilterBasis): ...@@ -271,9 +258,7 @@ class MorletFilterBasis(FilterBasis):
class ZernikeFilterBasis(FilterBasis): class ZernikeFilterBasis(FilterBasis):
""" """Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials"""
Zernike polynomials which are defined on the disk. See https://en.wikipedia.org/wiki/Zernike_polynomials
"""
def __init__( def __init__(
self, self,
......
...@@ -37,25 +37,11 @@ from torch_harmonics.cache import lru_cache ...@@ -37,25 +37,11 @@ from torch_harmonics.cache import lru_cache
def clm(l: int, m: int) -> float: def clm(l: int, m: int) -> float:
""" """Defines the normalization factor to orthonormalize the Spherical Harmonics."""
defines the normalization factor to orthonormalize the Spherical Harmonics
Parameters
-----------
l: int
Degree of the spherical harmonic
m: int
Order of the spherical harmonic
Returns
-------
out: float
Normalization factor
"""
return math.sqrt((2*l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l-m) / math.factorial(l+m)) return math.sqrt((2*l + 1) / 4 / math.pi) * math.sqrt(math.factorial(l-m) / math.factorial(l+m))
def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor: def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r""" """
Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x. Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally. can be turned off optionally.
...@@ -127,7 +113,7 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho", ...@@ -127,7 +113,7 @@ def legpoly(mmax: int, lmax: int, x: torch.Tensor, norm: Optional[str]="ortho",
@lru_cache(typed=True, copy=True) @lru_cache(typed=True, copy=True)
def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor, def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor: norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r""" """
Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta). Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally. can be turned off optionally.
...@@ -165,7 +151,7 @@ def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor, ...@@ -165,7 +151,7 @@ def _precompute_legpoly(mmax: int , lmax: int, t: torch.Tensor,
@lru_cache(typed=True, copy=True) @lru_cache(typed=True, copy=True)
def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor, def _precompute_dlegpoly(mmax: int, lmax: int, t: torch.Tensor,
norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor: norm: Optional[str]="ortho", inverse: Optional[bool]=False, csphase: Optional[bool]=True) -> torch.Tensor:
r""" """
Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$ Computes the values of the derivatives $\frac{d}{d \theta} P^m_l(\cos \theta)$
at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$, at the positions specified by t (theta), as well as $\frac{1}{\sin \theta} P^m_l(\cos \theta)$,
needed for the computation of the vector spherical harmonics. The resulting tensor has shape needed for the computation of the vector spherical harmonics. The resulting tensor has shape
......
...@@ -47,14 +47,6 @@ except ImportError as err: ...@@ -47,14 +47,6 @@ except ImportError as err:
def check_plotting_dependencies(): def check_plotting_dependencies():
"""
Check if required plotting dependencies (matplotlib and cartopy) are available.
Raises
------
ImportError
If matplotlib or cartopy is not installed
"""
if plt is None: if plt is None:
raise ImportError("matplotlib is required for plotting functions. Install it with 'pip install matplotlib'") raise ImportError("matplotlib is required for plotting functions. Install it with 'pip install matplotlib'")
if cartopy is None: if cartopy is None:
......
...@@ -37,7 +37,7 @@ import torch ...@@ -37,7 +37,7 @@ import torch
def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0, def _precompute_grid(n: int, grid: Optional[str]="equidistant", a: Optional[float]=0.0, b: Optional[float]=1.0,
periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]: periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
r""" """
Precompute grid points and weights for various quadrature rules. Precompute grid points and weights for various quadrature rules.
Parameters Parameters
...@@ -103,7 +103,7 @@ def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple ...@@ -103,7 +103,7 @@ def _precompute_latitudes(nlat: int, grid: Optional[str]="equiangular") -> Tuple
def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0, periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]: def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0, periodic: Optional[bool]=False) -> Tuple[torch.Tensor, torch.Tensor]:
r""" """
Helper routine which returns equidistant nodes with trapezoidal weights Helper routine which returns equidistant nodes with trapezoidal weights
on the interval [a, b] on the interval [a, b]
...@@ -137,7 +137,7 @@ def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0, ...@@ -137,7 +137,7 @@ def trapezoidal_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]: def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r""" """
Helper routine which returns the Legendre-Gauss nodes and weights Helper routine which returns the Legendre-Gauss nodes and weights
on the interval [a, b] on the interval [a, b]
...@@ -169,7 +169,7 @@ def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1 ...@@ -169,7 +169,7 @@ def legendre_gauss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1
def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0, def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
tol: Optional[float]=1e-16, maxiter: Optional[int]=100) -> Tuple[torch.Tensor, torch.Tensor]: tol: Optional[float]=1e-16, maxiter: Optional[int]=100) -> Tuple[torch.Tensor, torch.Tensor]:
r""" """
Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights Helper routine which returns the Legendre-Gauss-Lobatto nodes and weights
on the interval [a, b] on the interval [a, b]
...@@ -232,7 +232,7 @@ def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0, ...@@ -232,7 +232,7 @@ def lobatto_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0,
def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]: def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r""" """
Computation of the Clenshaw-Curtis quadrature nodes and weights. Computation of the Clenshaw-Curtis quadrature nodes and weights.
This implementation follows This implementation follows
...@@ -289,7 +289,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float] ...@@ -289,7 +289,7 @@ def clenshaw_curtiss_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]
def fejer2_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]: def fejer2_weights(n: int, a: Optional[float]=-1.0, b: Optional[float]=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
r""" """
Computation of the Fejer quadrature nodes and weights. Computation of the Fejer quadrature nodes and weights.
Parameters Parameters
......
...@@ -137,11 +137,9 @@ class ResampleS2(nn.Module): ...@@ -137,11 +137,9 @@ class ResampleS2(nn.Module):
def extra_repr(self): def extra_repr(self):
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}" return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
def _upscale_longitudes(self, x: torch.Tensor): def _upscale_longitudes(self, x: torch.Tensor):
# do the interpolation in precision of x # do the interpolation in precision of x
lwgt = self.lon_weights.to(x.dtype) lwgt = self.lon_weights.to(x.dtype)
if self.mode == "bilinear": if self.mode == "bilinear":
...@@ -156,19 +154,6 @@ class ResampleS2(nn.Module): ...@@ -156,19 +154,6 @@ class ResampleS2(nn.Module):
return x return x
def _expand_poles(self, x: torch.Tensor): def _expand_poles(self, x: torch.Tensor):
"""
Expand the input tensor to include pole points for interpolation.
Parameters
-----------
x : torch.Tensor
Input tensor with shape (..., nlat, nlon)
Returns
-------
torch.Tensor
Expanded tensor with pole points added
"""
x_north = x[..., 0, :].mean(dim=-1, keepdims=True) x_north = x[..., 0, :].mean(dim=-1, keepdims=True)
x_south = x[..., -1, :].mean(dim=-1, keepdims=True) x_south = x[..., -1, :].mean(dim=-1, keepdims=True)
x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode='constant') x = nn.functional.pad(x, pad=[0, 0, 1, 1], mode='constant')
...@@ -178,7 +163,6 @@ class ResampleS2(nn.Module): ...@@ -178,7 +163,6 @@ class ResampleS2(nn.Module):
return x return x
def _upscale_latitudes(self, x: torch.Tensor): def _upscale_latitudes(self, x: torch.Tensor):
# do the interpolation in precision of x # do the interpolation in precision of x
lwgt = self.lat_weights.to(x.dtype) lwgt = self.lat_weights.to(x.dtype)
if self.mode == "bilinear": if self.mode == "bilinear":
...@@ -193,7 +177,6 @@ class ResampleS2(nn.Module): ...@@ -193,7 +177,6 @@ class ResampleS2(nn.Module):
return x return x
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
if self.skip_resampling: if self.skip_resampling:
return x return x
......
...@@ -38,7 +38,7 @@ from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly ...@@ -38,7 +38,7 @@ from torch_harmonics.legendre import _precompute_legpoly, _precompute_dlegpoly
class RealSHT(nn.Module): class RealSHT(nn.Module):
r""" """
Defines a module for computing the forward (real-valued) SHT. Defines a module for computing the forward (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last two dimensions of the input The SHT is applied to the last two dimensions of the input
...@@ -149,7 +149,7 @@ class RealSHT(nn.Module): ...@@ -149,7 +149,7 @@ class RealSHT(nn.Module):
class InverseRealSHT(nn.Module): class InverseRealSHT(nn.Module):
r""" """
Defines a module for computing the inverse (real-valued) SHT. Defines a module for computing the inverse (real-valued) SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
...@@ -250,7 +250,7 @@ class InverseRealSHT(nn.Module): ...@@ -250,7 +250,7 @@ class InverseRealSHT(nn.Module):
class RealVectorSHT(nn.Module): class RealVectorSHT(nn.Module):
r""" """
Defines a module for computing the forward (real) vector SHT. Defines a module for computing the forward (real) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
The SHT is applied to the last three dimensions of the input. The SHT is applied to the last three dimensions of the input.
...@@ -373,7 +373,7 @@ class RealVectorSHT(nn.Module): ...@@ -373,7 +373,7 @@ class RealVectorSHT(nn.Module):
class InverseRealVectorSHT(nn.Module): class InverseRealVectorSHT(nn.Module):
r""" """
Defines a module for computing the inverse (real-valued) vector SHT. Defines a module for computing the inverse (real-valued) vector SHT.
Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes. Precomputes Legendre Gauss nodes, weights and associated Legendre polynomials on these nodes.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment