Commit 6f3250cb authored by apaaris's avatar apaaris Committed by Boris Bonev
Browse files

Improved docstrings in torch_harmonics/examples

parent 50ebe96f
......@@ -492,10 +492,19 @@ class NormalLossS2(SphericalLossBase):
Surface normals are computed by calculating gradients in latitude and longitude
directions using FFT, then constructing 3D normal vectors that are normalized.
Args:
nlat (int): Number of latitude points
nlon (int): Number of longitude points
grid (str, optional): Grid type. Defaults to "equiangular".
Parameters
----------
nlat : int
Number of latitude points
nlon : int
Number of longitude points
grid : str, optional
Grid type, by default "equiangular"
Returns
-------
torch.Tensor
Combined loss term
"""
def __init__(self, nlat: int, nlon: int, grid: str = "equiangular"):
......
......@@ -118,13 +118,21 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
Parameters
-----------
tensor: torch.Tensor
an n-dimensional `torch.Tensor`
mean: float
the mean of the normal distribution
std: float
the standard deviation of the normal distribution
a: float
the minimum cutoff value, by default -2.0
b: float
the maximum cutoff value
Examples
--------
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
......@@ -139,6 +147,20 @@ def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
Parameters
----------
x : torch.Tensor
Input tensor
drop_prob : float, optional
Probability of dropping a path, by default 0.0
training : bool, optional
Whether the model is in training mode, by default False
Returns
-------
torch.Tensor
Output tensor
"""
if drop_prob == 0.0 or not training:
return x
......@@ -159,7 +181,7 @@ class DropPath(nn.Module):
training of very deep networks.
Parameters
-----------
----------
drop_prob : float, optional
Probability of dropping a path, by default None
"""
......@@ -173,7 +195,7 @@ class DropPath(nn.Module):
Forward pass with drop path regularization.
Parameters
-----------
----------
x : torch.Tensor
Input tensor
......@@ -193,7 +215,7 @@ class PatchEmbed(nn.Module):
higher dimensional embedding space using convolutional layers.
Parameters
-----------
----------
img_size : tuple, optional
Input image size (height, width), by default (224, 224)
patch_size : tuple, optional
......@@ -220,7 +242,7 @@ class PatchEmbed(nn.Module):
Forward pass of patch embedding.
Parameters
-----------
----------
x : torch.Tensor
Input tensor of shape (batch_size, channels, height, width)
......@@ -245,7 +267,7 @@ class MLP(nn.Module):
and an activation function, with optional dropout and gradient checkpointing.
Parameters
-----------
----------
in_features : int
Number of input features
hidden_features : int, optional
......@@ -301,7 +323,7 @@ class MLP(nn.Module):
Forward pass with gradient checkpointing.
Parameters
-----------
----------
x : torch.Tensor
Input tensor
......@@ -317,7 +339,7 @@ class MLP(nn.Module):
Forward pass of the MLP.
Parameters
-----------
----------
x : torch.Tensor
Input tensor
......@@ -364,7 +386,7 @@ class RealFFT2(nn.Module):
Forward pass: compute real FFT2D.
Parameters
-----------
----------
x : torch.Tensor
Input tensor
......@@ -410,7 +432,7 @@ class InverseRealFFT2(nn.Module):
Forward pass: compute inverse real FFT2D.
Parameters
-----------
----------
x : torch.Tensor
Input FFT coefficients
......@@ -431,7 +453,7 @@ class LayerNorm(nn.Module):
applying normalization, and then transposing back.
Parameters
-----------
----------
in_channels : int
Number of input channels
eps : float, optional
......@@ -458,7 +480,7 @@ class LayerNorm(nn.Module):
Forward pass with channel dimension handling.
Parameters
-----------
----------
x : torch.Tensor
Input tensor with channel dimension at -3
......@@ -477,7 +499,7 @@ class SpectralConvS2(nn.Module):
domain via the RealFFT2 and InverseRealFFT2 wrappers.
Parameters
-----------
----------
forward_transform : nn.Module
Forward transform (SHT or FFT)
inverse_transform : nn.Module
......@@ -538,7 +560,7 @@ class SpectralConvS2(nn.Module):
Forward pass of spectral convolution.
Parameters
-----------
----------
x : torch.Tensor
Input tensor
......@@ -576,7 +598,7 @@ class PositionEmbedding(nn.Module, metaclass=abc.ABCMeta):
that add positional information to input tensors.
Parameters
-----------
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
......@@ -616,7 +638,7 @@ class SequencePositionEmbedding(PositionEmbedding):
used in the original Transformer paper, adapted for 2D spatial data.
Parameters
-----------
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
......@@ -696,7 +718,7 @@ class LearnablePositionEmbedding(PositionEmbedding):
latitude-only or full latitude-longitude embeddings.
Parameters
-----------
----------
img_shape : tuple, optional
Image shape (height, width), by default (480, 960)
grid : str, optional
......
......@@ -57,7 +57,7 @@ class DiscreteContinuousEncoder(nn.Module):
reducing the spatial resolution while maintaining the spectral properties of the data.
Parameters
-----------
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (721, 1440)
out_shape : tuple, optional
......@@ -114,7 +114,7 @@ class DiscreteContinuousEncoder(nn.Module):
Forward pass of the discrete-continuous encoder.
Parameters
-----------
----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
......@@ -141,7 +141,7 @@ class DiscreteContinuousDecoder(nn.Module):
followed by discrete-continuous convolutions to restore spatial resolution.
Parameters
-----------
----------
in_shape : tuple, optional
Input shape (nlat, nlon), by default (480, 960)
out_shape : tuple, optional
......@@ -209,7 +209,7 @@ class DiscreteContinuousDecoder(nn.Module):
Forward pass of the discrete-continuous decoder.
Parameters
-----------
----------
x : torch.Tensor
Input tensor with shape (batch, channels, nlat, nlon)
......@@ -232,6 +232,46 @@ class DiscreteContinuousDecoder(nn.Module):
class SphericalNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
conv_type : str, optional
Type of convolution to use, by default "local"
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
disco_kernel_shape : tuple, optional
Kernel shape for discrete-continuous convolution, by default (3, 3)
disco_basis_type : str, optional
Filter basis type for discrete-continuous convolution, by default "morlet"
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
"""
def __init__(
......@@ -367,7 +407,7 @@ class LocalSphericalNeuralOperator(nn.Module):
as well as in the encoder and decoders.
Parameters
-----------
----------
img_size : tuple, optional
Input image size (nlat, nlon), by default (128, 256)
grid : str, optional
......@@ -416,7 +456,7 @@ class LocalSphericalNeuralOperator(nn.Module):
Whether to use a bias, by default False
Example
-----------
----------
>>> model = LocalSphericalNeuralOperator(
... img_shape=(128, 256),
... scale_factor=4,
......@@ -429,7 +469,7 @@ class LocalSphericalNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256])
References
-----------
----------
.. [1] Liu-Schiaffini M., Berner J., Bonev B., Kurth T., Azizzadenesheli K., Anandkumar A.;
"Neural Operators with Localized Integral and Differential Kernels" (2024).
ICML 2024, https://arxiv.org/pdf/2402.16845.
......@@ -592,7 +632,7 @@ class LocalSphericalNeuralOperator(nn.Module):
Forward pass through the complete LSNO model.
Parameters
-----------
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
......
......@@ -43,6 +43,40 @@ from functools import partial
class SphericalFourierNeuralOperatorBlock(nn.Module):
"""
Helper module for a single SFNO/FNO block. Can use both FFTs and SHTs to represent either FNO or SFNO blocks.
Parameters
----------
forward_transform : torch.nn.Module
Forward transform to use for the block
inverse_transform : torch.nn.Module
Inverse transform to use for the block
input_dim : int
Input dimension
output_dim : int
Output dimension
mlp_ratio : float, optional
MLP expansion ratio, by default 2.0
drop_rate : float, optional
Dropout rate, by default 0.0
drop_path : float, optional
Drop path rate, by default 0.0
act_layer : torch.nn.Module, optional
Activation function to use, by default nn.GELU
norm_layer : str, optional
Type of normalization to use, by default "none"
inner_skip : str, optional
Type of inner skip connection to use, by default "none"
outer_skip : str, optional
Type of outer skip connection to use, by default "identity"
use_mlp : bool, optional
Whether to use MLP layers, by default True
bias : bool, optional
Whether to use bias, by default False
Returns
-------
torch.Tensor
Output tensor
"""
def __init__(
......@@ -123,12 +157,12 @@ class SphericalFourierNeuralOperatorBlock(nn.Module):
Forward pass through the SFNO block.
Parameters
-----------
----------
x : torch.Tensor
Input tensor
Returns
-------
----------
torch.Tensor
Output tensor after processing through the block
"""
......@@ -198,7 +232,7 @@ class SphericalFourierNeuralOperator(nn.Module):
Whether to use a bias, by default False
Example:
--------
----------
>>> model = SphericalFourierNeuralOperator(
... img_size=(128, 256),
... scale_factor=4,
......@@ -211,7 +245,7 @@ class SphericalFourierNeuralOperator(nn.Module):
torch.Size([1, 2, 128, 256])
References
-----------
----------
.. [1] Bonev B., Kurth T., Hundt C., Pathak, J., Baust M., Kashinath K., Anandkumar A.;
"Spherical Fourier Neural Operators: Learning Stable Dynamics on the Sphere" (2023).
ICML 2023, https://arxiv.org/abs/2306.03838.
......@@ -385,12 +419,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Forward pass through the feature extraction layers.
Parameters
-----------
----------
x : torch.Tensor
Input tensor
Returns
-------
----------
torch.Tensor
Features after processing through the network
"""
......@@ -406,12 +440,12 @@ class SphericalFourierNeuralOperator(nn.Module):
Forward pass through the complete SFNO model.
Parameters
-----------
----------
x : torch.Tensor
Input tensor of shape (batch_size, in_chans, height, width)
Returns
-------
----------
torch.Tensor
Output tensor of shape (batch_size, out_chans, height, width)
"""
......
......@@ -37,7 +37,38 @@ from .shallow_water_equations import ShallowWaterSolver
class PdeDataset(torch.utils.data.Dataset):
"""Custom Dataset class for PDE training data"""
"""Custom Dataset class for PDE training data
Parameters
----------
dt : float
Time step
nsteps : int
Number of solver steps
dims : tuple, optional
Number of latitude and longitude points, by default (384, 768)
grid : str, optional
Grid type, by default "equiangular"
pde : str, optional
PDE type, by default "shallow water equations"
initial_condition : str, optional
Initial condition type, by default "random"
num_examples : int, optional
Number of examples, by default 32
device : torch.device, optional
Device to use, by default torch.device("cpu")
normalize : bool, optional
Whether to normalize the input and target, by default True
stream : torch.cuda.Stream, optional
CUDA stream to use, by default None
Returns
-------
inp : torch.Tensor
Input tensor
tar : torch.Tensor
Target tensor
"""
def __init__(
self,
......
......@@ -302,13 +302,19 @@ class ShallowWaterSolver(nn.Module):
"""
Initialize non-linear barotropically unstable shallow water test case of Galewsky et al. (2004, Tellus, 56A, 429-440).
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
Parameters
----------
None
Returns
-------
torch.Tensor
Initial spectral coefficients for the Galewsky test case
References
----------
[1] Galewsky; An initial-value problem for testing numerical models of the global shallow-water equations;
DOI: 10.1111/j.1600-0870.2004.00071.x; http://www-vortex.mcs.st-and.ac.uk/~rks/reprints/galewsky_etal_tellus_2004.pdf
"""
device = self.lap.device
......@@ -407,6 +413,18 @@ class ShallowWaterSolver(nn.Module):
def timestep(self, uspec: torch.Tensor, nsteps: int) -> torch.Tensor:
"""
Integrate the solution using Adams-Bashforth / forward Euler for nsteps steps.
Parameters
----------
uspec : torch.Tensor
Spectral coefficients [height, vorticity, divergence]
nsteps : int
Number of time steps to integrate
Returns
-------
torch.Tensor
Integrated spectral coefficients
"""
dudtspec = torch.zeros(3, 3, self.lmax, self.mmax, dtype=uspec.dtype, device=uspec.device)
......@@ -440,6 +458,23 @@ class ShallowWaterSolver(nn.Module):
return uspec
def integrate_grid(self, ugrid, dimensionless=False, polar_opt=0):
"""
Integrate the solution on the grid.
Parameters
----------
ugrid : torch.Tensor
Grid data
dimensionless : bool, optional
Whether to use dimensionless units, by default False
polar_opt : int, optional
Number of polar points to exclude, by default 0
Returns
-------
torch.Tensor
Integrated grid data
"""
dlon = 2 * torch.pi / self.nlon
radius = 1 if dimensionless else self.radius
if polar_opt > 0:
......
......@@ -58,8 +58,22 @@ class Stanford2D3DSDownloader:
"""
Convenience class for downloading the 2d3ds dataset [1].
Parameters
----------
base_url : str, optional
Base URL for downloading the dataset, by default DEFAULT_BASE_URL
local_dir : str, optional
Local directory to store downloaded files, by default "data"
Returns
-------
data_folders : list
List of extracted directory names
class_labels : list
List of semantic class labels
References
-----------
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
......@@ -483,8 +497,24 @@ class StanfordSegmentationDataset(Dataset):
"""
Spherical segmentation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
Returns
-------
StanfordSegmentationDataset
Dataset object
References
-----------
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
......@@ -620,8 +650,19 @@ class StanfordDepthDataset(Dataset):
"""
Spherical segmentation dataset from [1].
Parameters
----------
dataset_file : str
Path to the HDF5 dataset file
ignore_alpha_channel : bool, optional
Whether to ignore the alpha channel in the RGB images, by default True
log_depth : bool, optional
Whether to log the depth values, by default False
exclude_polar_fraction : float, optional
Fraction of polar points to exclude, by default 0.0
References
-----------
----------
.. [1] Armeni, I., Sax, S., Zamir, A. R., Savarese, S.;
"Joint 2D-3D-Semantic Data for Indoor Scene Understanding" (2017).
https://arxiv.org/abs/1702.01105.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment